mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'development' into mkdocs-updates
This commit is contained in:
commit
555f21cd25
@ -5,8 +5,7 @@ SAMPLES_DIR=${OUT_DIR}
|
|||||||
python scripts/dream.py \
|
python scripts/dream.py \
|
||||||
--from_file ${PROMPT_FILE} \
|
--from_file ${PROMPT_FILE} \
|
||||||
--outdir ${OUT_DIR} \
|
--outdir ${OUT_DIR} \
|
||||||
--sampler plms \
|
--sampler plms
|
||||||
--full_precision
|
|
||||||
|
|
||||||
# original output by CompVis/stable-diffusion
|
# original output by CompVis/stable-diffusion
|
||||||
IMAGE1=".dev_scripts/images/v1_4_astronaut_rides_horse_plms_step50_seed42.png"
|
IMAGE1=".dev_scripts/images/v1_4_astronaut_rides_horse_plms_step50_seed42.png"
|
||||||
|
4
.github/workflows/test-dream-conda.yml
vendored
4
.github/workflows/test-dream-conda.yml
vendored
@ -85,9 +85,9 @@ jobs:
|
|||||||
fi
|
fi
|
||||||
# Utterly hacky, but I don't know how else to do this
|
# Utterly hacky, but I don't know how else to do this
|
||||||
if [[ ${{ github.ref }} == 'refs/heads/master' ]]; then
|
if [[ ${{ github.ref }} == 'refs/heads/master' ]]; then
|
||||||
time ${{ steps.vars.outputs.PYTHON_BIN }} scripts/dream.py --from_file tests/preflight_prompts.txt --full_precision
|
time ${{ steps.vars.outputs.PYTHON_BIN }} scripts/dream.py --from_file tests/preflight_prompts.txt
|
||||||
elif [[ ${{ github.ref }} == 'refs/heads/development' ]]; then
|
elif [[ ${{ github.ref }} == 'refs/heads/development' ]]; then
|
||||||
time ${{ steps.vars.outputs.PYTHON_BIN }} scripts/dream.py --from_file tests/dev_prompts.txt --full_precision
|
time ${{ steps.vars.outputs.PYTHON_BIN }} scripts/dream.py --from_file tests/dev_prompts.txt
|
||||||
fi
|
fi
|
||||||
mkdir -p outputs/img-samples
|
mkdir -p outputs/img-samples
|
||||||
- name: Archive results
|
- name: Archive results
|
||||||
|
18
README.md
18
README.md
@ -86,17 +86,14 @@ You wil need one of the following:
|
|||||||
|
|
||||||
- At least 6 GB of free disk space for the machine learning model, Python, and all its dependencies.
|
- At least 6 GB of free disk space for the machine learning model, Python, and all its dependencies.
|
||||||
|
|
||||||
> Note
|
#### Note
|
||||||
>
|
|
||||||
> If you have an Nvidia 10xx series card (e.g. the 1080ti), please run the dream script in
|
|
||||||
> full-precision mode as shown below.
|
|
||||||
|
|
||||||
Similarly, specify full-precision mode on Apple M1 hardware.
|
Precision is auto configured based on the device. If however you encounter
|
||||||
|
errors like 'expected type Float but found Half' or 'not implemented for Half'
|
||||||
To run in full-precision mode, start `dream.py` with the `--full_precision` flag:
|
you can try starting `dream.py` with the `--precision=float32` flag:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
(ldm) ~/stable-diffusion$ python scripts/dream.py --full_precision
|
(ldm) ~/stable-diffusion$ python scripts/dream.py --precision=float32
|
||||||
```
|
```
|
||||||
|
|
||||||
### Features
|
### Features
|
||||||
@ -125,6 +122,11 @@ To run in full-precision mode, start `dream.py` with the `--full_precision` flag
|
|||||||
|
|
||||||
### Latest Changes
|
### Latest Changes
|
||||||
|
|
||||||
|
- vNEXT (TODO 2022)
|
||||||
|
|
||||||
|
- Deprecated `--full_precision` / `-F`. Simply omit it and `dream.py` will auto
|
||||||
|
configure. To switch away from auto use the new flag like `--precision=float32`.
|
||||||
|
|
||||||
- v1.14 (11 September 2022)
|
- v1.14 (11 September 2022)
|
||||||
|
|
||||||
- Memory optimizations for small-RAM cards. 512x512 now possible on 4 GB GPUs.
|
- Memory optimizations for small-RAM cards. 512x512 now possible on 4 GB GPUs.
|
||||||
|
@ -2,14 +2,14 @@ from modules.parse_seed_weights import parse_seed_weights
|
|||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
SAMPLER_CHOICES = [
|
SAMPLER_CHOICES = [
|
||||||
'ddim',
|
"ddim",
|
||||||
'k_dpm_2_a',
|
"k_dpm_2_a",
|
||||||
'k_dpm_2',
|
"k_dpm_2",
|
||||||
'k_euler_a',
|
"k_euler_a",
|
||||||
'k_euler',
|
"k_euler",
|
||||||
'k_heun',
|
"k_heun",
|
||||||
'k_lms',
|
"k_lms",
|
||||||
'plms',
|
"plms",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@ -20,194 +20,42 @@ def parameters_to_command(params):
|
|||||||
|
|
||||||
switches = list()
|
switches = list()
|
||||||
|
|
||||||
if 'prompt' in params:
|
if "prompt" in params:
|
||||||
switches.append(f'"{params["prompt"]}"')
|
switches.append(f'"{params["prompt"]}"')
|
||||||
if 'steps' in params:
|
if "steps" in params:
|
||||||
switches.append(f'-s {params["steps"]}')
|
switches.append(f'-s {params["steps"]}')
|
||||||
if 'seed' in params:
|
if "seed" in params:
|
||||||
switches.append(f'-S {params["seed"]}')
|
switches.append(f'-S {params["seed"]}')
|
||||||
if 'width' in params:
|
if "width" in params:
|
||||||
switches.append(f'-W {params["width"]}')
|
switches.append(f'-W {params["width"]}')
|
||||||
if 'height' in params:
|
if "height" in params:
|
||||||
switches.append(f'-H {params["height"]}')
|
switches.append(f'-H {params["height"]}')
|
||||||
if 'cfg_scale' in params:
|
if "cfg_scale" in params:
|
||||||
switches.append(f'-C {params["cfg_scale"]}')
|
switches.append(f'-C {params["cfg_scale"]}')
|
||||||
if 'sampler_name' in params:
|
if "sampler_name" in params:
|
||||||
switches.append(f'-A {params["sampler_name"]}')
|
switches.append(f'-A {params["sampler_name"]}')
|
||||||
if 'seamless' in params and params["seamless"] == True:
|
if "seamless" in params and params["seamless"] == True:
|
||||||
switches.append(f'--seamless')
|
switches.append(f"--seamless")
|
||||||
if 'init_img' in params and len(params['init_img']) > 0:
|
if "init_img" in params and len(params["init_img"]) > 0:
|
||||||
switches.append(f'-I {params["init_img"]}')
|
switches.append(f'-I {params["init_img"]}')
|
||||||
if 'init_mask' in params and len(params['init_mask']) > 0:
|
if "init_mask" in params and len(params["init_mask"]) > 0:
|
||||||
switches.append(f'-M {params["init_mask"]}')
|
switches.append(f'-M {params["init_mask"]}')
|
||||||
if 'init_color' in params and len(params['init_color']) > 0:
|
if "init_color" in params and len(params["init_color"]) > 0:
|
||||||
switches.append(f'--init_color {params["init_color"]}')
|
switches.append(f'--init_color {params["init_color"]}')
|
||||||
if 'strength' in params and 'init_img' in params:
|
if "strength" in params and "init_img" in params:
|
||||||
switches.append(f'-f {params["strength"]}')
|
switches.append(f'-f {params["strength"]}')
|
||||||
if 'fit' in params and params["fit"] == True:
|
if "fit" in params and params["fit"] == True:
|
||||||
switches.append(f'--fit')
|
switches.append(f"--fit")
|
||||||
if 'gfpgan_strength' in params and params["gfpgan_strength"]:
|
if "gfpgan_strength" in params and params["gfpgan_strength"]:
|
||||||
switches.append(f'-G {params["gfpgan_strength"]}')
|
switches.append(f'-G {params["gfpgan_strength"]}')
|
||||||
if 'upscale' in params and params["upscale"]:
|
if "upscale" in params and params["upscale"]:
|
||||||
switches.append(f'-U {params["upscale"][0]} {params["upscale"][1]}')
|
switches.append(f'-U {params["upscale"][0]} {params["upscale"][1]}')
|
||||||
if 'variation_amount' in params and params['variation_amount'] > 0:
|
if "variation_amount" in params and params["variation_amount"] > 0:
|
||||||
switches.append(f'-v {params["variation_amount"]}')
|
switches.append(f'-v {params["variation_amount"]}')
|
||||||
if 'with_variations' in params:
|
if "with_variations" in params:
|
||||||
seed_weight_pairs = ','.join(f'{seed}:{weight}' for seed, weight in params["with_variations"])
|
seed_weight_pairs = ",".join(
|
||||||
switches.append(f'-V {seed_weight_pairs}')
|
f"{seed}:{weight}" for seed, weight in params["with_variations"]
|
||||||
|
)
|
||||||
|
switches.append(f"-V {seed_weight_pairs}")
|
||||||
|
|
||||||
return ' '.join(switches)
|
return " ".join(switches)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def create_cmd_parser():
|
|
||||||
"""
|
|
||||||
This is simply a copy of the parser from `dream.py` with a change to give
|
|
||||||
prompt a default value. This is a temporary hack pending merge of #587 which
|
|
||||||
provides a better way to do this.
|
|
||||||
"""
|
|
||||||
parser = argparse.ArgumentParser(
|
|
||||||
description='Example: dream> a fantastic alien landscape -W1024 -H960 -s100 -n12',
|
|
||||||
exit_on_error=True,
|
|
||||||
)
|
|
||||||
parser.add_argument('prompt', nargs='?', default='')
|
|
||||||
parser.add_argument('-s', '--steps', type=int, help='Number of steps')
|
|
||||||
parser.add_argument(
|
|
||||||
'-S',
|
|
||||||
'--seed',
|
|
||||||
type=int,
|
|
||||||
help='Image seed; a +ve integer, or use -1 for the previous seed, -2 for the one before that, etc',
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
'-n',
|
|
||||||
'--iterations',
|
|
||||||
type=int,
|
|
||||||
default=1,
|
|
||||||
help='Number of samplings to perform (slower, but will provide seeds for individual images)',
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
'-W', '--width', type=int, help='Image width, multiple of 64'
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
'-H', '--height', type=int, help='Image height, multiple of 64'
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
'-C',
|
|
||||||
'--cfg_scale',
|
|
||||||
default=7.5,
|
|
||||||
type=float,
|
|
||||||
help='Classifier free guidance (CFG) scale - higher numbers cause generator to "try" harder.',
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
'-g', '--grid', action='store_true', help='generate a grid'
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
'--outdir',
|
|
||||||
'-o',
|
|
||||||
type=str,
|
|
||||||
default=None,
|
|
||||||
help='Directory to save generated images and a log of prompts and seeds',
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
'--seamless',
|
|
||||||
action='store_true',
|
|
||||||
help='Change the model to seamless tiling (circular) mode',
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
'-i',
|
|
||||||
'--individual',
|
|
||||||
action='store_true',
|
|
||||||
help='Generate individual files (default)',
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
'-I',
|
|
||||||
'--init_img',
|
|
||||||
type=str,
|
|
||||||
help='Path to input image for img2img mode (supersedes width and height)',
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
'-M',
|
|
||||||
'--init_mask',
|
|
||||||
type=str,
|
|
||||||
help='Path to input mask for inpainting mode (supersedes width and height)',
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
'--init_color',
|
|
||||||
type=str,
|
|
||||||
help='Path to reference image for color correction (used for repeated img2img and inpainting)'
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
'-T',
|
|
||||||
'-fit',
|
|
||||||
'--fit',
|
|
||||||
action='store_true',
|
|
||||||
help='If specified, will resize the input image to fit within the dimensions of width x height (512x512 default)',
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
'-f',
|
|
||||||
'--strength',
|
|
||||||
default=0.75,
|
|
||||||
type=float,
|
|
||||||
help='Strength for noising/unnoising. 0.0 preserves image exactly, 1.0 replaces it completely',
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
'-G',
|
|
||||||
'--gfpgan_strength',
|
|
||||||
default=0,
|
|
||||||
type=float,
|
|
||||||
help='The strength at which to apply the GFPGAN model to the result, in order to improve faces.',
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
'-U',
|
|
||||||
'--upscale',
|
|
||||||
nargs='+',
|
|
||||||
default=None,
|
|
||||||
type=float,
|
|
||||||
help='Scale factor (2, 4) for upscaling followed by upscaling strength (0-1.0). If strength not specified, defaults to 0.75'
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
'-save_orig',
|
|
||||||
'--save_original',
|
|
||||||
action='store_true',
|
|
||||||
help='Save original. Use it when upscaling to save both versions.',
|
|
||||||
)
|
|
||||||
# variants is going to be superseded by a generalized "prompt-morph" function
|
|
||||||
# parser.add_argument('-v','--variants',type=int,help="in img2img mode, the first generated image will get passed back to img2img to generate the requested number of variants")
|
|
||||||
parser.add_argument(
|
|
||||||
'-x',
|
|
||||||
'--skip_normalize',
|
|
||||||
action='store_true',
|
|
||||||
help='Skip subprompt weight normalization',
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
'-A',
|
|
||||||
'-m',
|
|
||||||
'--sampler',
|
|
||||||
dest='sampler_name',
|
|
||||||
default=None,
|
|
||||||
type=str,
|
|
||||||
choices=SAMPLER_CHOICES,
|
|
||||||
metavar='SAMPLER_NAME',
|
|
||||||
help=f'Switch to a different sampler. Supported samplers: {", ".join(SAMPLER_CHOICES)}',
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
'-t',
|
|
||||||
'--log_tokenization',
|
|
||||||
action='store_true',
|
|
||||||
help='shows how the prompt is split into tokens'
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
'-v',
|
|
||||||
'--variation_amount',
|
|
||||||
default=0.0,
|
|
||||||
type=float,
|
|
||||||
help='If > 0, generates variations on the initial seed instead of random seeds per iteration. Must be between 0 and 1. Higher values will be more different.'
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
'-V',
|
|
||||||
'--with_variations',
|
|
||||||
default=None,
|
|
||||||
type=str,
|
|
||||||
help='list of variations to apply, in the format `seed:weight,seed:weight,...'
|
|
||||||
)
|
|
||||||
return parser
|
|
||||||
|
@ -6,7 +6,8 @@ import traceback
|
|||||||
import eventlet
|
import eventlet
|
||||||
import glob
|
import glob
|
||||||
import shlex
|
import shlex
|
||||||
import argparse
|
import math
|
||||||
|
import shutil
|
||||||
|
|
||||||
from flask_socketio import SocketIO
|
from flask_socketio import SocketIO
|
||||||
from flask import Flask, send_from_directory, url_for, jsonify
|
from flask import Flask, send_from_directory, url_for, jsonify
|
||||||
@ -15,13 +16,16 @@ from PIL import Image
|
|||||||
from pytorch_lightning import logging
|
from pytorch_lightning import logging
|
||||||
from threading import Event
|
from threading import Event
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
from send2trash import send2trash
|
||||||
|
|
||||||
from ldm.gfpgan.gfpgan_tools import real_esrgan_upscale
|
from ldm.gfpgan.gfpgan_tools import real_esrgan_upscale
|
||||||
from ldm.gfpgan.gfpgan_tools import run_gfpgan
|
from ldm.gfpgan.gfpgan_tools import run_gfpgan
|
||||||
from ldm.generate import Generate
|
from ldm.generate import Generate
|
||||||
from ldm.dream.pngwriter import PngWriter, retrieve_metadata
|
from ldm.dream.pngwriter import PngWriter, retrieve_metadata
|
||||||
|
from ldm.dream.args import APP_ID, APP_VERSION, calculate_init_img_hash
|
||||||
|
from ldm.dream.conditioning import split_weighted_subprompts
|
||||||
|
|
||||||
from modules.parameters import parameters_to_command, create_cmd_parser
|
from modules.parameters import parameters_to_command
|
||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
@ -29,12 +33,14 @@ USER CONFIG
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
output_dir = "outputs/" # Base output directory for images
|
output_dir = "outputs/" # Base output directory for images
|
||||||
#host = 'localhost' # Web & socket.io host
|
# host = 'localhost' # Web & socket.io host
|
||||||
host = '0.0.0.0' # Web & socket.io host
|
host = "localhost" # Web & socket.io host
|
||||||
port = 9090 # Web & socket.io port
|
port = 9090 # Web & socket.io port
|
||||||
verbose = False # enables copious socket.io logging
|
verbose = False # enables copious socket.io logging
|
||||||
additional_allowed_origins = ['http://localhost:9090'] # additional CORS allowed origins
|
additional_allowed_origins = [
|
||||||
|
"http://localhost:5173"
|
||||||
|
] # additional CORS allowed origins
|
||||||
|
model = "stable-diffusion-1.4"
|
||||||
|
|
||||||
"""
|
"""
|
||||||
END USER CONFIG
|
END USER CONFIG
|
||||||
@ -47,26 +53,23 @@ SERVER SETUP
|
|||||||
|
|
||||||
|
|
||||||
# fix missing mimetypes on windows due to registry wonkiness
|
# fix missing mimetypes on windows due to registry wonkiness
|
||||||
mimetypes.add_type('application/javascript', '.js')
|
mimetypes.add_type("application/javascript", ".js")
|
||||||
mimetypes.add_type('text/css', '.css')
|
mimetypes.add_type("text/css", ".css")
|
||||||
|
|
||||||
app = Flask(__name__, static_url_path='', static_folder='../frontend/dist/')
|
app = Flask(__name__, static_url_path="", static_folder="../frontend/dist/")
|
||||||
|
|
||||||
|
|
||||||
app.config['OUTPUTS_FOLDER'] = "../outputs"
|
app.config["OUTPUTS_FOLDER"] = "../outputs"
|
||||||
|
|
||||||
|
|
||||||
@app.route('/outputs/<path:filename>')
|
@app.route("/outputs/<path:filename>")
|
||||||
def outputs(filename):
|
def outputs(filename):
|
||||||
return send_from_directory(
|
return send_from_directory(app.config["OUTPUTS_FOLDER"], filename)
|
||||||
app.config['OUTPUTS_FOLDER'],
|
|
||||||
filename
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@app.route("/", defaults={'path': ''})
|
@app.route("/", defaults={"path": ""})
|
||||||
def serve(path):
|
def serve(path):
|
||||||
return send_from_directory(app.static_folder, 'index.html')
|
return send_from_directory(app.static_folder, "index.html")
|
||||||
|
|
||||||
|
|
||||||
logger = True if verbose else False
|
logger = True if verbose else False
|
||||||
@ -78,12 +81,12 @@ max_http_buffer_size = 10000000
|
|||||||
cors_allowed_origins = [f"http://{host}:{port}"] + additional_allowed_origins
|
cors_allowed_origins = [f"http://{host}:{port}"] + additional_allowed_origins
|
||||||
|
|
||||||
socketio = SocketIO(
|
socketio = SocketIO(
|
||||||
app,
|
app,
|
||||||
logger=logger,
|
logger=logger,
|
||||||
engineio_logger=engineio_logger,
|
engineio_logger=engineio_logger,
|
||||||
max_http_buffer_size=max_http_buffer_size,
|
max_http_buffer_size=max_http_buffer_size,
|
||||||
cors_allowed_origins=cors_allowed_origins,
|
cors_allowed_origins=cors_allowed_origins,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
@ -104,29 +107,31 @@ canceled = Event()
|
|||||||
|
|
||||||
# reduce logging outputs to error
|
# reduce logging outputs to error
|
||||||
transformers.logging.set_verbosity_error()
|
transformers.logging.set_verbosity_error()
|
||||||
logging.getLogger('pytorch_lightning').setLevel(logging.ERROR)
|
logging.getLogger("pytorch_lightning").setLevel(logging.ERROR)
|
||||||
|
|
||||||
# Initialize and load model
|
# Initialize and load model
|
||||||
model = Generate()
|
generate = Generate(model)
|
||||||
model.load_model()
|
generate.load_model()
|
||||||
|
|
||||||
|
|
||||||
# location for "finished" images
|
# location for "finished" images
|
||||||
result_path = os.path.join(output_dir, 'img-samples/')
|
result_path = os.path.join(output_dir, "img-samples/")
|
||||||
|
|
||||||
# temporary path for intermediates
|
# temporary path for intermediates
|
||||||
intermediate_path = os.path.join(result_path, 'intermediates/')
|
intermediate_path = os.path.join(result_path, "intermediates/")
|
||||||
|
|
||||||
# path for user-uploaded init images and masks
|
# path for user-uploaded init images and masks
|
||||||
init_path = os.path.join(result_path, 'init-images/')
|
init_image_path = os.path.join(result_path, "init-images/")
|
||||||
mask_path = os.path.join(result_path, 'mask-images/')
|
mask_image_path = os.path.join(result_path, "mask-images/")
|
||||||
|
|
||||||
# txt log
|
# txt log
|
||||||
log_path = os.path.join(result_path, 'dream_log.txt')
|
log_path = os.path.join(result_path, "dream_log.txt")
|
||||||
|
|
||||||
# make all output paths
|
# make all output paths
|
||||||
[os.makedirs(path, exist_ok=True)
|
[
|
||||||
for path in [result_path, intermediate_path, init_path, mask_path]]
|
os.makedirs(path, exist_ok=True)
|
||||||
|
for path in [result_path, intermediate_path, init_image_path, mask_image_path]
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
@ -139,126 +144,219 @@ SOCKET.IO LISTENERS
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@socketio.on('requestAllImages')
|
@socketio.on("requestSystemConfig")
|
||||||
|
def handle_request_capabilities():
|
||||||
|
print(f">> System config requested")
|
||||||
|
config = get_system_config()
|
||||||
|
socketio.emit("systemConfig", config)
|
||||||
|
|
||||||
|
|
||||||
|
@socketio.on("requestAllImages")
|
||||||
def handle_request_all_images():
|
def handle_request_all_images():
|
||||||
print(f'>> All images requested')
|
print(f">> All images requested")
|
||||||
parser = create_cmd_parser()
|
|
||||||
paths = list(filter(os.path.isfile, glob.glob(result_path + "*.png")))
|
paths = list(filter(os.path.isfile, glob.glob(result_path + "*.png")))
|
||||||
paths.sort(key=lambda x: os.path.getmtime(x))
|
paths.sort(key=lambda x: os.path.getmtime(x))
|
||||||
image_array = []
|
image_array = []
|
||||||
for path in paths:
|
for path in paths:
|
||||||
# image = Image.open(path)
|
metadata = retrieve_metadata(path)
|
||||||
all_metadata = retrieve_metadata(path)
|
image_array.append({"url": path, "metadata": metadata["sd-metadata"]})
|
||||||
if 'Dream' in all_metadata and not all_metadata['sd-metadata']:
|
socketio.emit("galleryImages", {"images": image_array})
|
||||||
metadata = vars(parser.parse_args(shlex.split(all_metadata['Dream'])))
|
eventlet.sleep(0)
|
||||||
else:
|
|
||||||
metadata = all_metadata['sd-metadata']
|
|
||||||
image_array.append({'path': path, 'metadata': metadata})
|
|
||||||
return make_response("OK", data=image_array)
|
|
||||||
|
|
||||||
|
|
||||||
@socketio.on('generateImage')
|
@socketio.on("generateImage")
|
||||||
def handle_generate_image_event(generation_parameters, esrgan_parameters, gfpgan_parameters):
|
def handle_generate_image_event(
|
||||||
print(f'>> Image generation requested: {generation_parameters}\nESRGAN parameters: {esrgan_parameters}\nGFPGAN parameters: {gfpgan_parameters}')
|
generation_parameters, esrgan_parameters, gfpgan_parameters
|
||||||
generate_images(
|
):
|
||||||
generation_parameters,
|
print(
|
||||||
esrgan_parameters,
|
f">> Image generation requested: {generation_parameters}\nESRGAN parameters: {esrgan_parameters}\nGFPGAN parameters: {gfpgan_parameters}"
|
||||||
gfpgan_parameters
|
|
||||||
)
|
)
|
||||||
return make_response("OK")
|
generate_images(generation_parameters, esrgan_parameters, gfpgan_parameters)
|
||||||
|
|
||||||
|
|
||||||
@socketio.on('runESRGAN')
|
@socketio.on("runESRGAN")
|
||||||
def handle_run_esrgan_event(original_image, esrgan_parameters):
|
def handle_run_esrgan_event(original_image, esrgan_parameters):
|
||||||
print(f'>> ESRGAN upscale requested for "{original_image["url"]}": {esrgan_parameters}')
|
print(
|
||||||
|
f'>> ESRGAN upscale requested for "{original_image["url"]}": {esrgan_parameters}'
|
||||||
|
)
|
||||||
|
progress = {
|
||||||
|
"currentStep": 1,
|
||||||
|
"totalSteps": 1,
|
||||||
|
"currentIteration": 1,
|
||||||
|
"totalIterations": 1,
|
||||||
|
"currentStatus": "Preparing",
|
||||||
|
"isProcessing": True,
|
||||||
|
"currentStatusHasSteps": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
socketio.emit("progressUpdate", progress)
|
||||||
|
eventlet.sleep(0)
|
||||||
|
|
||||||
image = Image.open(original_image["url"])
|
image = Image.open(original_image["url"])
|
||||||
|
|
||||||
seed = original_image['metadata']['seed'] if 'seed' in original_image['metadata'] else 'unknown_seed'
|
seed = (
|
||||||
|
original_image["metadata"]["seed"]
|
||||||
|
if "seed" in original_image["metadata"]
|
||||||
|
else "unknown_seed"
|
||||||
|
)
|
||||||
|
|
||||||
|
progress["currentStatus"] = "Upscaling"
|
||||||
|
socketio.emit("progressUpdate", progress)
|
||||||
|
eventlet.sleep(0)
|
||||||
|
|
||||||
image = real_esrgan_upscale(
|
image = real_esrgan_upscale(
|
||||||
image=image,
|
image=image,
|
||||||
upsampler_scale=esrgan_parameters['upscale'][0],
|
upsampler_scale=esrgan_parameters["upscale"][0],
|
||||||
strength=esrgan_parameters['upscale'][1],
|
strength=esrgan_parameters["upscale"][1],
|
||||||
seed=seed
|
seed=seed,
|
||||||
)
|
)
|
||||||
|
|
||||||
esrgan_parameters['seed'] = seed
|
progress["currentStatus"] = "Saving image"
|
||||||
path = save_image(image, esrgan_parameters, result_path, postprocessing='esrgan')
|
socketio.emit("progressUpdate", progress)
|
||||||
|
eventlet.sleep(0)
|
||||||
|
|
||||||
|
esrgan_parameters["seed"] = seed
|
||||||
|
metadata = parameters_to_post_processed_image_metadata(
|
||||||
|
parameters=esrgan_parameters,
|
||||||
|
original_image_path=original_image["url"],
|
||||||
|
type="esrgan",
|
||||||
|
)
|
||||||
command = parameters_to_command(esrgan_parameters)
|
command = parameters_to_command(esrgan_parameters)
|
||||||
|
|
||||||
|
path = save_image(image, command, metadata, result_path, postprocessing="esrgan")
|
||||||
|
|
||||||
write_log_message(f'[Upscaled] "{original_image["url"]}" > "{path}": {command}')
|
write_log_message(f'[Upscaled] "{original_image["url"]}" > "{path}": {command}')
|
||||||
|
|
||||||
|
progress["currentStatus"] = "Finished"
|
||||||
|
progress["currentStep"] = 0
|
||||||
|
progress["totalSteps"] = 0
|
||||||
|
progress["currentIteration"] = 0
|
||||||
|
progress["totalIterations"] = 0
|
||||||
|
progress["isProcessing"] = False
|
||||||
|
socketio.emit("progressUpdate", progress)
|
||||||
|
eventlet.sleep(0)
|
||||||
|
|
||||||
socketio.emit(
|
socketio.emit(
|
||||||
'result', {'url': os.path.relpath(path), 'type': 'esrgan', 'uuid': original_image['uuid'],'metadata': esrgan_parameters})
|
"esrganResult",
|
||||||
|
{
|
||||||
|
"url": os.path.relpath(path),
|
||||||
|
"metadata": metadata,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@socketio.on("runGFPGAN")
|
||||||
@socketio.on('runGFPGAN')
|
|
||||||
def handle_run_gfpgan_event(original_image, gfpgan_parameters):
|
def handle_run_gfpgan_event(original_image, gfpgan_parameters):
|
||||||
print(f'>> GFPGAN face fix requested for "{original_image["url"]}": {gfpgan_parameters}')
|
print(
|
||||||
|
f'>> GFPGAN face fix requested for "{original_image["url"]}": {gfpgan_parameters}'
|
||||||
|
)
|
||||||
|
progress = {
|
||||||
|
"currentStep": 1,
|
||||||
|
"totalSteps": 1,
|
||||||
|
"currentIteration": 1,
|
||||||
|
"totalIterations": 1,
|
||||||
|
"currentStatus": "Preparing",
|
||||||
|
"isProcessing": True,
|
||||||
|
"currentStatusHasSteps": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
socketio.emit("progressUpdate", progress)
|
||||||
|
eventlet.sleep(0)
|
||||||
|
|
||||||
image = Image.open(original_image["url"])
|
image = Image.open(original_image["url"])
|
||||||
|
|
||||||
seed = original_image['metadata']['seed'] if 'seed' in original_image['metadata'] else 'unknown_seed'
|
seed = (
|
||||||
|
original_image["metadata"]["seed"]
|
||||||
|
if "seed" in original_image["metadata"]
|
||||||
|
else "unknown_seed"
|
||||||
|
)
|
||||||
|
|
||||||
|
progress["currentStatus"] = "Fixing faces"
|
||||||
|
socketio.emit("progressUpdate", progress)
|
||||||
|
eventlet.sleep(0)
|
||||||
|
|
||||||
image = run_gfpgan(
|
image = run_gfpgan(
|
||||||
image=image,
|
image=image,
|
||||||
strength=gfpgan_parameters['gfpgan_strength'],
|
strength=gfpgan_parameters["gfpgan_strength"],
|
||||||
seed=seed,
|
seed=seed,
|
||||||
upsampler_scale=1
|
upsampler_scale=1,
|
||||||
)
|
)
|
||||||
|
|
||||||
gfpgan_parameters['seed'] = seed
|
progress["currentStatus"] = "Saving image"
|
||||||
path = save_image(image, gfpgan_parameters, result_path, postprocessing='gfpgan')
|
socketio.emit("progressUpdate", progress)
|
||||||
|
eventlet.sleep(0)
|
||||||
|
|
||||||
|
gfpgan_parameters["seed"] = seed
|
||||||
|
metadata = parameters_to_post_processed_image_metadata(
|
||||||
|
parameters=gfpgan_parameters,
|
||||||
|
original_image_path=original_image["url"],
|
||||||
|
type="gfpgan",
|
||||||
|
)
|
||||||
command = parameters_to_command(gfpgan_parameters)
|
command = parameters_to_command(gfpgan_parameters)
|
||||||
|
|
||||||
|
path = save_image(image, command, metadata, result_path, postprocessing="gfpgan")
|
||||||
|
|
||||||
write_log_message(f'[Fixed faces] "{original_image["url"]}" > "{path}": {command}')
|
write_log_message(f'[Fixed faces] "{original_image["url"]}" > "{path}": {command}')
|
||||||
|
|
||||||
|
progress["currentStatus"] = "Finished"
|
||||||
|
progress["currentStep"] = 0
|
||||||
|
progress["totalSteps"] = 0
|
||||||
|
progress["currentIteration"] = 0
|
||||||
|
progress["totalIterations"] = 0
|
||||||
|
progress["isProcessing"] = False
|
||||||
|
socketio.emit("progressUpdate", progress)
|
||||||
|
eventlet.sleep(0)
|
||||||
|
|
||||||
socketio.emit(
|
socketio.emit(
|
||||||
'result', {'url': os.path.relpath(path), 'type': 'gfpgan', 'uuid': original_image['uuid'],'metadata': gfpgan_parameters})
|
"gfpganResult",
|
||||||
|
{
|
||||||
|
"url": os.path.relpath(path),
|
||||||
|
"metadata": metadata,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@socketio.on('cancel')
|
@socketio.on("cancel")
|
||||||
def handle_cancel():
|
def handle_cancel():
|
||||||
print(f'>> Cancel processing requested')
|
print(f">> Cancel processing requested")
|
||||||
canceled.set()
|
canceled.set()
|
||||||
return make_response("OK")
|
socketio.emit("processingCanceled")
|
||||||
|
|
||||||
|
|
||||||
# TODO: I think this needs a safety mechanism.
|
# TODO: I think this needs a safety mechanism.
|
||||||
@socketio.on('deleteImage')
|
@socketio.on("deleteImage")
|
||||||
def handle_delete_image(path):
|
def handle_delete_image(path, uuid):
|
||||||
print(f'>> Delete requested "{path}"')
|
print(f'>> Delete requested "{path}"')
|
||||||
Path(path).unlink()
|
send2trash(path)
|
||||||
return make_response("OK")
|
socketio.emit("imageDeleted", {"url": path, "uuid": uuid})
|
||||||
|
|
||||||
|
|
||||||
# TODO: I think this needs a safety mechanism.
|
# TODO: I think this needs a safety mechanism.
|
||||||
@socketio.on('uploadInitialImage')
|
@socketio.on("uploadInitialImage")
|
||||||
def handle_upload_initial_image(bytes, name):
|
def handle_upload_initial_image(bytes, name):
|
||||||
print(f'>> Init image upload requested "{name}"')
|
print(f'>> Init image upload requested "{name}"')
|
||||||
uuid = uuid4().hex
|
uuid = uuid4().hex
|
||||||
split = os.path.splitext(name)
|
split = os.path.splitext(name)
|
||||||
name = f'{split[0]}.{uuid}{split[1]}'
|
name = f"{split[0]}.{uuid}{split[1]}"
|
||||||
file_path = os.path.join(init_path, name)
|
file_path = os.path.join(init_image_path, name)
|
||||||
os.makedirs(os.path.dirname(file_path), exist_ok=True)
|
os.makedirs(os.path.dirname(file_path), exist_ok=True)
|
||||||
newFile = open(file_path, "wb")
|
newFile = open(file_path, "wb")
|
||||||
newFile.write(bytes)
|
newFile.write(bytes)
|
||||||
return make_response("OK", data=file_path)
|
socketio.emit("initialImageUploaded", {"url": file_path, "uuid": ""})
|
||||||
|
|
||||||
|
|
||||||
# TODO: I think this needs a safety mechanism.
|
# TODO: I think this needs a safety mechanism.
|
||||||
@socketio.on('uploadMaskImage')
|
@socketio.on("uploadMaskImage")
|
||||||
def handle_upload_mask_image(bytes, name):
|
def handle_upload_mask_image(bytes, name):
|
||||||
print(f'>> Mask image upload requested "{name}"')
|
print(f'>> Mask image upload requested "{name}"')
|
||||||
uuid = uuid4().hex
|
uuid = uuid4().hex
|
||||||
split = os.path.splitext(name)
|
split = os.path.splitext(name)
|
||||||
name = f'{split[0]}.{uuid}{split[1]}'
|
name = f"{split[0]}.{uuid}{split[1]}"
|
||||||
file_path = os.path.join(mask_path, name)
|
file_path = os.path.join(mask_image_path, name)
|
||||||
os.makedirs(os.path.dirname(file_path), exist_ok=True)
|
os.makedirs(os.path.dirname(file_path), exist_ok=True)
|
||||||
newFile = open(file_path, "wb")
|
newFile = open(file_path, "wb")
|
||||||
newFile.write(bytes)
|
newFile.write(bytes)
|
||||||
return make_response("OK", data=file_path)
|
socketio.emit("maskImageUploaded", {"url": file_path, "uuid": ""})
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
@ -266,114 +364,343 @@ END SOCKET.IO LISTENERS
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
ADDITIONAL FUNCTIONS
|
ADDITIONAL FUNCTIONS
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def get_system_config():
|
||||||
|
return {
|
||||||
|
"model": "stable diffusion",
|
||||||
|
"model_id": model,
|
||||||
|
"model_hash": generate.model_hash,
|
||||||
|
"app_id": APP_ID,
|
||||||
|
"app_version": APP_VERSION,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def parameters_to_post_processed_image_metadata(parameters, original_image_path, type):
|
||||||
|
# top-level metadata minus `image` or `images`
|
||||||
|
metadata = get_system_config()
|
||||||
|
|
||||||
|
orig_hash = calculate_init_img_hash(original_image_path)
|
||||||
|
|
||||||
|
image = {"orig_path": original_image_path, "orig_hash": orig_hash}
|
||||||
|
|
||||||
|
if type == "esrgan":
|
||||||
|
image["type"] = "esrgan"
|
||||||
|
image["scale"] = parameters["upscale"][0]
|
||||||
|
image["strength"] = parameters["upscale"][1]
|
||||||
|
elif type == "gfpgan":
|
||||||
|
image["type"] = "gfpgan"
|
||||||
|
image["strength"] = parameters["gfpgan_strength"]
|
||||||
|
else:
|
||||||
|
raise TypeError(f"Invalid type: {type}")
|
||||||
|
|
||||||
|
metadata["image"] = image
|
||||||
|
return metadata
|
||||||
|
|
||||||
|
|
||||||
|
def parameters_to_generated_image_metadata(parameters):
|
||||||
|
# top-level metadata minus `image` or `images`
|
||||||
|
|
||||||
|
metadata = get_system_config()
|
||||||
|
# remove any image keys not mentioned in RFC #266
|
||||||
|
rfc266_img_fields = [
|
||||||
|
"type",
|
||||||
|
"postprocessing",
|
||||||
|
"sampler",
|
||||||
|
"prompt",
|
||||||
|
"seed",
|
||||||
|
"variations",
|
||||||
|
"steps",
|
||||||
|
"cfg_scale",
|
||||||
|
"step_number",
|
||||||
|
"width",
|
||||||
|
"height",
|
||||||
|
"extra",
|
||||||
|
"seamless",
|
||||||
|
]
|
||||||
|
|
||||||
|
rfc_dict = {}
|
||||||
|
|
||||||
|
for item in parameters.items():
|
||||||
|
key, value = item
|
||||||
|
if key in rfc266_img_fields:
|
||||||
|
rfc_dict[key] = value
|
||||||
|
|
||||||
|
postprocessing = []
|
||||||
|
|
||||||
|
# 'postprocessing' is either null or an
|
||||||
|
if "gfpgan_strength" in parameters:
|
||||||
|
|
||||||
|
postprocessing.append(
|
||||||
|
{"type": "gfpgan", "strength": float(parameters["gfpgan_strength"])}
|
||||||
|
)
|
||||||
|
|
||||||
|
if "upscale" in parameters:
|
||||||
|
postprocessing.append(
|
||||||
|
{
|
||||||
|
"type": "esrgan",
|
||||||
|
"scale": int(parameters["upscale"][0]),
|
||||||
|
"strength": float(parameters["upscale"][1]),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
rfc_dict["postprocessing"] = postprocessing if len(postprocessing) > 0 else None
|
||||||
|
|
||||||
|
# semantic drift
|
||||||
|
rfc_dict["sampler"] = parameters["sampler_name"]
|
||||||
|
|
||||||
|
# display weighted subprompts (liable to change)
|
||||||
|
subprompts = split_weighted_subprompts(parameters["prompt"])
|
||||||
|
subprompts = [{"prompt": x[0], "weight": x[1]} for x in subprompts]
|
||||||
|
rfc_dict["prompt"] = subprompts
|
||||||
|
|
||||||
|
# 'variations' should always exist and be an array, empty or consisting of {'seed': seed, 'weight': weight} pairs
|
||||||
|
variations = []
|
||||||
|
|
||||||
|
if "with_variations" in parameters:
|
||||||
|
variations = [
|
||||||
|
{"seed": x[0], "weight": x[1]} for x in parameters["with_variations"]
|
||||||
|
]
|
||||||
|
|
||||||
|
rfc_dict["variations"] = variations
|
||||||
|
|
||||||
|
if "init_img" in parameters:
|
||||||
|
rfc_dict["type"] = "img2img"
|
||||||
|
rfc_dict["strength"] = parameters["strength"]
|
||||||
|
rfc_dict["fit"] = parameters["fit"] # TODO: Noncompliant
|
||||||
|
rfc_dict["orig_hash"] = calculate_init_img_hash(parameters["init_img"])
|
||||||
|
rfc_dict["init_image_path"] = parameters["init_img"] # TODO: Noncompliant
|
||||||
|
rfc_dict["sampler"] = "ddim" # TODO: FIX ME WHEN IMG2IMG SUPPORTS ALL SAMPLERS
|
||||||
|
if "init_mask" in parameters:
|
||||||
|
rfc_dict["mask_hash"] = calculate_init_img_hash(
|
||||||
|
parameters["init_mask"]
|
||||||
|
) # TODO: Noncompliant
|
||||||
|
rfc_dict["mask_image_path"] = parameters["init_mask"] # TODO: Noncompliant
|
||||||
|
else:
|
||||||
|
rfc_dict["type"] = "txt2img"
|
||||||
|
|
||||||
|
metadata["image"] = rfc_dict
|
||||||
|
|
||||||
|
return metadata
|
||||||
|
|
||||||
|
|
||||||
|
def make_unique_init_image_filename(name):
|
||||||
|
uuid = uuid4().hex
|
||||||
|
split = os.path.splitext(name)
|
||||||
|
name = f"{split[0]}.{uuid}{split[1]}"
|
||||||
|
return name
|
||||||
|
|
||||||
|
|
||||||
def write_log_message(message, log_path=log_path):
|
def write_log_message(message, log_path=log_path):
|
||||||
"""Logs the filename and parameters used to generate or process that image to log file"""
|
"""Logs the filename and parameters used to generate or process that image to log file"""
|
||||||
message = f'{message}\n'
|
message = f"{message}\n"
|
||||||
with open(log_path, 'a', encoding='utf-8') as file:
|
with open(log_path, "a", encoding="utf-8") as file:
|
||||||
file.writelines(message)
|
file.writelines(message)
|
||||||
|
|
||||||
|
|
||||||
def make_response(status, message=None, data=None):
|
def save_image(
|
||||||
response = {'status': status}
|
image, command, metadata, output_dir, step_index=None, postprocessing=False
|
||||||
if message is not None:
|
):
|
||||||
response['message'] = message
|
|
||||||
if data is not None:
|
|
||||||
response['data'] = data
|
|
||||||
return response
|
|
||||||
|
|
||||||
|
|
||||||
def save_image(image, parameters, output_dir, step_index=None, postprocessing=False):
|
|
||||||
seed = parameters['seed'] if 'seed' in parameters else 'unknown_seed'
|
|
||||||
|
|
||||||
pngwriter = PngWriter(output_dir)
|
pngwriter = PngWriter(output_dir)
|
||||||
prefix = pngwriter.unique_prefix()
|
prefix = pngwriter.unique_prefix()
|
||||||
|
|
||||||
filename = f'{prefix}.{seed}'
|
seed = "unknown_seed"
|
||||||
|
|
||||||
|
if "image" in metadata:
|
||||||
|
if "seed" in metadata["image"]:
|
||||||
|
seed = metadata["image"]["seed"]
|
||||||
|
|
||||||
|
filename = f"{prefix}.{seed}"
|
||||||
|
|
||||||
if step_index:
|
if step_index:
|
||||||
filename += f'.{step_index}'
|
filename += f".{step_index}"
|
||||||
if postprocessing:
|
if postprocessing:
|
||||||
filename += f'.postprocessed'
|
filename += f".postprocessed"
|
||||||
|
|
||||||
filename += '.png'
|
filename += ".png"
|
||||||
|
|
||||||
command = parameters_to_command(parameters)
|
path = pngwriter.save_image_and_prompt_to_png(
|
||||||
|
image=image, dream_prompt=command, metadata=metadata, name=filename
|
||||||
path = pngwriter.save_image_and_prompt_to_png(image, command, metadata=parameters, name=filename)
|
)
|
||||||
|
|
||||||
return path
|
return path
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_real_steps(steps, strength, has_init_image):
|
||||||
|
return math.floor(strength * steps) if has_init_image else steps
|
||||||
|
|
||||||
|
|
||||||
def generate_images(generation_parameters, esrgan_parameters, gfpgan_parameters):
|
def generate_images(generation_parameters, esrgan_parameters, gfpgan_parameters):
|
||||||
canceled.clear()
|
canceled.clear()
|
||||||
|
|
||||||
step_index = 1
|
step_index = 1
|
||||||
|
|
||||||
|
"""
|
||||||
|
If a result image is used as an init image, and then deleted, we will want to be
|
||||||
|
able to use it as an init image in the future. Need to copy it.
|
||||||
|
|
||||||
|
If the init/mask image doesn't exist in the init_image_path/mask_image_path,
|
||||||
|
make a unique filename for it and copy it there.
|
||||||
|
"""
|
||||||
|
if "init_img" in generation_parameters:
|
||||||
|
filename = os.path.basename(generation_parameters["init_img"])
|
||||||
|
if not os.path.exists(os.path.join(init_image_path, filename)):
|
||||||
|
unique_filename = make_unique_init_image_filename(filename)
|
||||||
|
new_path = os.path.join(init_image_path, unique_filename)
|
||||||
|
shutil.copy(generation_parameters["init_img"], new_path)
|
||||||
|
generation_parameters["init_img"] = new_path
|
||||||
|
if "init_mask" in generation_parameters:
|
||||||
|
filename = os.path.basename(generation_parameters["init_mask"])
|
||||||
|
if not os.path.exists(os.path.join(mask_image_path, filename)):
|
||||||
|
unique_filename = make_unique_init_image_filename(filename)
|
||||||
|
new_path = os.path.join(init_image_path, unique_filename)
|
||||||
|
shutil.copy(generation_parameters["init_img"], new_path)
|
||||||
|
generation_parameters["init_mask"] = new_path
|
||||||
|
|
||||||
|
totalSteps = calculate_real_steps(
|
||||||
|
steps=generation_parameters["steps"],
|
||||||
|
strength=generation_parameters["strength"]
|
||||||
|
if "strength" in generation_parameters
|
||||||
|
else None,
|
||||||
|
has_init_image="init_img" in generation_parameters,
|
||||||
|
)
|
||||||
|
|
||||||
|
progress = {
|
||||||
|
"currentStep": 1,
|
||||||
|
"totalSteps": totalSteps,
|
||||||
|
"currentIteration": 1,
|
||||||
|
"totalIterations": generation_parameters["iterations"],
|
||||||
|
"currentStatus": "Preparing",
|
||||||
|
"isProcessing": True,
|
||||||
|
"currentStatusHasSteps": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
socketio.emit("progressUpdate", progress)
|
||||||
|
eventlet.sleep(0)
|
||||||
|
|
||||||
def image_progress(sample, step):
|
def image_progress(sample, step):
|
||||||
if canceled.is_set():
|
if canceled.is_set():
|
||||||
raise CanceledException
|
raise CanceledException
|
||||||
|
|
||||||
nonlocal step_index
|
nonlocal step_index
|
||||||
nonlocal generation_parameters
|
nonlocal generation_parameters
|
||||||
if generation_parameters["progress_images"] and step % 5 == 0 and step < generation_parameters['steps'] - 1:
|
nonlocal progress
|
||||||
image = model.sample_to_image(sample)
|
|
||||||
path = save_image(image, generation_parameters, intermediate_path, step_index)
|
progress["currentStep"] = step + 1
|
||||||
|
progress["currentStatus"] = "Generating"
|
||||||
|
progress["currentStatusHasSteps"] = True
|
||||||
|
|
||||||
|
if (
|
||||||
|
generation_parameters["progress_images"]
|
||||||
|
and step % 5 == 0
|
||||||
|
and step < generation_parameters["steps"] - 1
|
||||||
|
):
|
||||||
|
image = generate.sample_to_image(sample)
|
||||||
|
path = save_image(
|
||||||
|
image, generation_parameters, intermediate_path, step_index
|
||||||
|
)
|
||||||
|
|
||||||
step_index += 1
|
step_index += 1
|
||||||
socketio.emit('intermediateResult', {
|
socketio.emit(
|
||||||
'url': os.path.relpath(path), 'metadata': generation_parameters})
|
"intermediateResult",
|
||||||
socketio.emit('progress', {'step': step + 1})
|
{"url": os.path.relpath(path), "metadata": generation_parameters},
|
||||||
|
)
|
||||||
|
socketio.emit("progressUpdate", progress)
|
||||||
eventlet.sleep(0)
|
eventlet.sleep(0)
|
||||||
|
|
||||||
def image_done(image, seed):
|
def image_done(image, seed):
|
||||||
nonlocal generation_parameters
|
nonlocal generation_parameters
|
||||||
nonlocal esrgan_parameters
|
nonlocal esrgan_parameters
|
||||||
nonlocal gfpgan_parameters
|
nonlocal gfpgan_parameters
|
||||||
|
nonlocal progress
|
||||||
|
|
||||||
|
step_index = 1
|
||||||
|
|
||||||
|
progress["currentStatus"] = "Generation complete"
|
||||||
|
socketio.emit("progressUpdate", progress)
|
||||||
|
eventlet.sleep(0)
|
||||||
|
|
||||||
all_parameters = generation_parameters
|
all_parameters = generation_parameters
|
||||||
postprocessing = False
|
postprocessing = False
|
||||||
|
|
||||||
if esrgan_parameters:
|
if esrgan_parameters:
|
||||||
|
progress["currentStatus"] = "Upscaling"
|
||||||
|
progress["currentStatusHasSteps"] = False
|
||||||
|
socketio.emit("progressUpdate", progress)
|
||||||
|
eventlet.sleep(0)
|
||||||
|
|
||||||
image = real_esrgan_upscale(
|
image = real_esrgan_upscale(
|
||||||
image=image,
|
image=image,
|
||||||
strength=esrgan_parameters['strength'],
|
strength=esrgan_parameters["strength"],
|
||||||
upsampler_scale=esrgan_parameters['level'],
|
upsampler_scale=esrgan_parameters["level"],
|
||||||
seed=seed
|
seed=seed,
|
||||||
)
|
)
|
||||||
postprocessing = True
|
postprocessing = True
|
||||||
all_parameters["upscale"] = [esrgan_parameters['level'], esrgan_parameters['strength']]
|
all_parameters["upscale"] = [
|
||||||
|
esrgan_parameters["level"],
|
||||||
|
esrgan_parameters["strength"],
|
||||||
|
]
|
||||||
|
|
||||||
if gfpgan_parameters:
|
if gfpgan_parameters:
|
||||||
|
progress["currentStatus"] = "Fixing faces"
|
||||||
|
progress["currentStatusHasSteps"] = False
|
||||||
|
socketio.emit("progressUpdate", progress)
|
||||||
|
eventlet.sleep(0)
|
||||||
|
|
||||||
image = run_gfpgan(
|
image = run_gfpgan(
|
||||||
image=image,
|
image=image,
|
||||||
strength=gfpgan_parameters['strength'],
|
strength=gfpgan_parameters["strength"],
|
||||||
seed=seed,
|
seed=seed,
|
||||||
upsampler_scale=1,
|
upsampler_scale=1,
|
||||||
)
|
)
|
||||||
postprocessing = True
|
postprocessing = True
|
||||||
all_parameters["gfpgan_strength"] = gfpgan_parameters['strength']
|
all_parameters["gfpgan_strength"] = gfpgan_parameters["strength"]
|
||||||
|
|
||||||
all_parameters['seed'] = seed
|
all_parameters["seed"] = seed
|
||||||
|
progress["currentStatus"] = "Saving image"
|
||||||
|
socketio.emit("progressUpdate", progress)
|
||||||
|
eventlet.sleep(0)
|
||||||
|
|
||||||
path = save_image(image, all_parameters, result_path, postprocessing=postprocessing)
|
metadata = parameters_to_generated_image_metadata(all_parameters)
|
||||||
command = parameters_to_command(all_parameters)
|
command = parameters_to_command(all_parameters)
|
||||||
|
|
||||||
print(f'Image generated: "{path}"')
|
path = save_image(
|
||||||
|
image, command, metadata, result_path, postprocessing=postprocessing
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f'>> Image generated: "{path}"')
|
||||||
write_log_message(f'[Generated] "{path}": {command}')
|
write_log_message(f'[Generated] "{path}": {command}')
|
||||||
|
|
||||||
|
if progress["totalIterations"] > progress["currentIteration"]:
|
||||||
|
progress["currentStep"] = 1
|
||||||
|
progress["currentIteration"] += 1
|
||||||
|
progress["currentStatus"] = "Iteration finished"
|
||||||
|
progress["currentStatusHasSteps"] = False
|
||||||
|
else:
|
||||||
|
progress["currentStep"] = 0
|
||||||
|
progress["totalSteps"] = 0
|
||||||
|
progress["currentIteration"] = 0
|
||||||
|
progress["totalIterations"] = 0
|
||||||
|
progress["currentStatus"] = "Finished"
|
||||||
|
progress["isProcessing"] = False
|
||||||
|
|
||||||
|
socketio.emit("progressUpdate", progress)
|
||||||
|
eventlet.sleep(0)
|
||||||
|
|
||||||
socketio.emit(
|
socketio.emit(
|
||||||
'result', {'url': os.path.relpath(path), 'type': 'generation', 'metadata': all_parameters})
|
"generationResult",
|
||||||
|
{"url": os.path.relpath(path), "metadata": metadata},
|
||||||
|
)
|
||||||
eventlet.sleep(0)
|
eventlet.sleep(0)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
model.prompt2image(
|
generate.prompt2image(
|
||||||
**generation_parameters,
|
**generation_parameters,
|
||||||
step_callback=image_progress,
|
step_callback=image_progress,
|
||||||
image_callback=image_done
|
image_callback=image_done,
|
||||||
)
|
)
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
@ -381,7 +708,7 @@ def generate_images(generation_parameters, esrgan_parameters, gfpgan_parameters)
|
|||||||
except CanceledException:
|
except CanceledException:
|
||||||
pass
|
pass
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
socketio.emit('error', (str(e)))
|
socketio.emit("error", {"message": (str(e))})
|
||||||
print("\n")
|
print("\n")
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
print("\n")
|
print("\n")
|
||||||
@ -392,6 +719,6 @@ END ADDITIONAL FUNCTIONS
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
print(f'Starting server at http://{host}:{port}')
|
print(f">> Starting server at http://{host}:{port}")
|
||||||
socketio.run(app, host=host, port=port)
|
socketio.run(app, host=host, port=port)
|
||||||
|
Binary file not shown.
Before Width: | Height: | Size: 22 KiB After Width: | Height: | Size: 22 KiB |
@ -59,9 +59,7 @@ Once the model is trained, specify the trained .pt or .bin file when starting
|
|||||||
dream using
|
dream using
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python3 ./scripts/dream.py \
|
python3 ./scripts/dream.py --embedding_path /path/to/embedding.pt
|
||||||
--embedding_path /path/to/embedding.pt \
|
|
||||||
--full_precision
|
|
||||||
```
|
```
|
||||||
|
|
||||||
Then, to utilize your subject at the dream prompt
|
Then, to utilize your subject at the dream prompt
|
||||||
|
@ -97,6 +97,11 @@ You wil need one of the following:
|
|||||||
```
|
```
|
||||||
## :octicons-log-16: Latest Changes
|
## :octicons-log-16: Latest Changes
|
||||||
|
|
||||||
|
### vNEXT <small>(TODO 2022)</small>
|
||||||
|
|
||||||
|
- Deprecated `--full_precision` / `-F`. Simply omit it and `dream.py` will auto
|
||||||
|
configure. To switch away from auto use the new flag like `--precision=float32`.
|
||||||
|
|
||||||
### v1.14 <small>(11 September 2022)</small>
|
### v1.14 <small>(11 September 2022)</small>
|
||||||
|
|
||||||
- Memory optimizations for small-RAM cards. 512x512 now possible on 4 GB GPUs.
|
- Memory optimizations for small-RAM cards. 512x512 now possible on 4 GB GPUs.
|
||||||
|
@ -106,7 +106,6 @@ PATH_TO_CKPT="$HOME/Downloads" # (1)!
|
|||||||
|
|
||||||
ln -s "$PATH_TO_CKPT/sd-v1-4.ckpt" \
|
ln -s "$PATH_TO_CKPT/sd-v1-4.ckpt" \
|
||||||
models/ldm/stable-diffusion-v1/model.ckpt
|
models/ldm/stable-diffusion-v1/model.ckpt
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
1. or wherever you saved sd-v1-4.ckpt
|
1. or wherever you saved sd-v1-4.ckpt
|
||||||
@ -548,5 +547,3 @@ Abort trap: 6
|
|||||||
warnings.warn('resource_tracker: There appear to be %d '
|
warnings.warn('resource_tracker: There appear to be %d '
|
||||||
```
|
```
|
||||||
|
|
||||||
Macs do not support `autocast/mixed-precision`, so you need to supply
|
|
||||||
`--full_precision` to use float32 everywhere.
|
|
||||||
|
@ -32,6 +32,7 @@ dependencies:
|
|||||||
- omegaconf==2.1.1
|
- omegaconf==2.1.1
|
||||||
- onnx==1.12.0
|
- onnx==1.12.0
|
||||||
- onnxruntime==1.12.1
|
- onnxruntime==1.12.1
|
||||||
|
- protobuf==3.20.1
|
||||||
- pudb==2022.1
|
- pudb==2022.1
|
||||||
- pytorch-lightning==1.6.5
|
- pytorch-lightning==1.6.5
|
||||||
- scipy==1.9.1
|
- scipy==1.9.1
|
||||||
@ -48,6 +49,7 @@ dependencies:
|
|||||||
- opencv-python==4.6.0
|
- opencv-python==4.6.0
|
||||||
- protobuf==3.20.1
|
- protobuf==3.20.1
|
||||||
- realesrgan==0.2.5.0
|
- realesrgan==0.2.5.0
|
||||||
|
- send2trash==1.8.0
|
||||||
- test-tube==0.7.5
|
- test-tube==0.7.5
|
||||||
- transformers==4.21.2
|
- transformers==4.21.2
|
||||||
- torch-fidelity==0.3.0
|
- torch-fidelity==0.3.0
|
||||||
|
@ -20,6 +20,7 @@ dependencies:
|
|||||||
- realesrgan==0.2.5.0
|
- realesrgan==0.2.5.0
|
||||||
- test-tube>=0.7.5
|
- test-tube>=0.7.5
|
||||||
- streamlit==1.12.0
|
- streamlit==1.12.0
|
||||||
|
- send2trash==1.8.0
|
||||||
- pillow==9.2.0
|
- pillow==9.2.0
|
||||||
- einops==0.3.0
|
- einops==0.3.0
|
||||||
- torch-fidelity==0.3.0
|
- torch-fidelity==0.3.0
|
||||||
|
@ -1,85 +1,37 @@
|
|||||||
# Stable Diffusion Web UI
|
# Stable Diffusion Web UI
|
||||||
|
|
||||||
Demo at https://peaceful-otter-7a427f.netlify.app/ (not connected to back end)
|
## Run
|
||||||
|
|
||||||
much of this readme is just notes for myself during dev work
|
- `python backend/server.py` serves both frontend and backend at http://localhost:9090
|
||||||
|
|
||||||
numpy rand: 0 to 4294967295
|
## Evironment
|
||||||
|
|
||||||
## Test and Build
|
Install [node](https://nodejs.org/en/download/) (includes npm) and optionally
|
||||||
|
[yarn](https://yarnpkg.com/getting-started/install).
|
||||||
|
|
||||||
from `frontend/`:
|
From `frontend/` run `npm install` / `yarn install` to install the frontend packages.
|
||||||
|
|
||||||
- `yarn dev` runs `tsc-watch`, which runs `vite build` on successful `tsc` transpilation
|
## Dev
|
||||||
|
|
||||||
from `.`:
|
1. From `frontend/`, run `npm dev` / `yarn dev` to start the dev server.
|
||||||
|
2. Note the address it starts up on (probably `http://localhost:5173/`).
|
||||||
|
3. Edit `backend/server.py`'s `additional_allowed_origins` to include this address, e.g.
|
||||||
|
`additional_allowed_origins = ['http://localhost:5173']`.
|
||||||
|
4. Leaving the dev server running, open a new terminal and go to the project root.
|
||||||
|
5. Run `python backend/server.py`.
|
||||||
|
6. Navigate to the dev server address e.g. `http://localhost:5173/`.
|
||||||
|
|
||||||
- `python backend/server.py` serves both frontend and backend at http://localhost:9090
|
To build for dev: `npm build-dev` / `yarn build-dev`
|
||||||
|
|
||||||
## API
|
To build for production: `npm build` / `yarn build`
|
||||||
|
|
||||||
`backend/server.py` serves the UI and provides a [socket.io](https://github.com/socketio/socket.io) API via [flask-socketio](https://github.com/miguelgrinberg/flask-socketio).
|
|
||||||
|
|
||||||
### Server Listeners
|
|
||||||
|
|
||||||
The server listens for these socket.io events:
|
|
||||||
|
|
||||||
`cancel`
|
|
||||||
|
|
||||||
- Cancels in-progress image generation
|
|
||||||
- Returns ack only
|
|
||||||
|
|
||||||
`generateImage`
|
|
||||||
|
|
||||||
- Accepts object of image parameters
|
|
||||||
- Generates an image
|
|
||||||
- Returns ack only (image generation function sends progress and result via separate events)
|
|
||||||
|
|
||||||
`deleteImage`
|
|
||||||
|
|
||||||
- Accepts file path to image
|
|
||||||
- Deletes image
|
|
||||||
- Returns ack only
|
|
||||||
|
|
||||||
`deleteAllImages` WIP
|
|
||||||
|
|
||||||
- Deletes all images in `outputs/`
|
|
||||||
- Returns ack only
|
|
||||||
|
|
||||||
`requestAllImages`
|
|
||||||
|
|
||||||
- Returns array of all images in `outputs/`
|
|
||||||
|
|
||||||
`requestCapabilities` WIP
|
|
||||||
|
|
||||||
- Returns capabilities of server (torch device, GFPGAN and ESRGAN availability, ???)
|
|
||||||
|
|
||||||
`sendImage` WIP
|
|
||||||
|
|
||||||
- Accepts a File and attributes
|
|
||||||
- Saves image
|
|
||||||
- Used to save init images which are not generated images
|
|
||||||
|
|
||||||
### Server Emitters
|
|
||||||
|
|
||||||
`progress`
|
|
||||||
|
|
||||||
- Emitted during each step in generation
|
|
||||||
- Sends a number from 0 to 1 representing percentage of steps completed
|
|
||||||
|
|
||||||
`result` WIP
|
|
||||||
|
|
||||||
- Emitted when an image generation has completed
|
|
||||||
- Sends a object:
|
|
||||||
|
|
||||||
```
|
|
||||||
{
|
|
||||||
url: relative_file_path,
|
|
||||||
metadata: image_metadata_object
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
## TODO
|
## TODO
|
||||||
|
|
||||||
- Search repo for "TODO"
|
- Search repo for "TODO"
|
||||||
- My one gripe with Chakra: no way to disable all animations right now and drop the dependence on `framer-motion`. I would prefer to save the ~30kb on bundle and have zero animations. This is on the Chakra roadmap. See https://github.com/chakra-ui/chakra-ui/pull/6368 for last discussion on this. Need to check in on this issue periodically.
|
- My one gripe with Chakra: no way to disable all animations right now and drop the dependence on
|
||||||
|
`framer-motion`. I would prefer to save the ~30kb on bundle and have zero animations. This is on
|
||||||
|
the Chakra roadmap. See https://github.com/chakra-ui/chakra-ui/pull/6368 for last discussion on
|
||||||
|
this. Need to check in on this issue periodically.
|
||||||
|
- Mobile friendly layout
|
||||||
|
- Proper image gallery/viewer/manager
|
||||||
|
- Help tooltips and such
|
||||||
|
694
frontend/dist/assets/index.727a397b.js
vendored
Normal file
694
frontend/dist/assets/index.727a397b.js
vendored
Normal file
File diff suppressed because one or more lines are too long
695
frontend/dist/assets/index.cc5cde43.js
vendored
695
frontend/dist/assets/index.cc5cde43.js
vendored
File diff suppressed because one or more lines are too long
2
frontend/dist/index.html
vendored
2
frontend/dist/index.html
vendored
@ -4,7 +4,7 @@
|
|||||||
<meta charset="UTF-8" />
|
<meta charset="UTF-8" />
|
||||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||||
<title>Stable Diffusion Dream Server</title>
|
<title>Stable Diffusion Dream Server</title>
|
||||||
<script type="module" crossorigin src="/assets/index.cc5cde43.js"></script>
|
<script type="module" crossorigin src="/assets/index.727a397b.js"></script>
|
||||||
<link rel="stylesheet" href="/assets/index.447eb2a9.css">
|
<link rel="stylesheet" href="/assets/index.447eb2a9.css">
|
||||||
</head>
|
</head>
|
||||||
<body>
|
<body>
|
||||||
|
@ -3,7 +3,7 @@
|
|||||||
<head>
|
<head>
|
||||||
<meta charset="UTF-8" />
|
<meta charset="UTF-8" />
|
||||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||||
<title>Stable Diffusion Dream Server</title>
|
<title>InvokeAI Stable Diffusion Dream Server</title>
|
||||||
</head>
|
</head>
|
||||||
<body>
|
<body>
|
||||||
<div id="root"></div>
|
<div id="root"></div>
|
||||||
|
@ -1,16 +1,16 @@
|
|||||||
{
|
{
|
||||||
"name": "sdui",
|
"name": "invoke-ai-ui",
|
||||||
"private": true,
|
"private": true,
|
||||||
"version": "0.0.0",
|
"version": "0.0.1",
|
||||||
"type": "module",
|
"type": "module",
|
||||||
"scripts": {
|
"scripts": {
|
||||||
"dev": "tsc-watch --onSuccess 'yarn run vite build -m development'",
|
"dev": "vite dev",
|
||||||
"hmr": "vite dev",
|
|
||||||
"build": "tsc && vite build",
|
"build": "tsc && vite build",
|
||||||
"build-dev": "tsc && vite build -m development",
|
"build-dev": "tsc && vite build -m development",
|
||||||
"preview": "vite preview"
|
"preview": "vite preview"
|
||||||
},
|
},
|
||||||
"dependencies": {
|
"dependencies": {
|
||||||
|
"@chakra-ui/icons": "^2.0.10",
|
||||||
"@chakra-ui/react": "^2.3.1",
|
"@chakra-ui/react": "^2.3.1",
|
||||||
"@emotion/react": "^11.10.4",
|
"@emotion/react": "^11.10.4",
|
||||||
"@emotion/styled": "^11.10.4",
|
"@emotion/styled": "^11.10.4",
|
||||||
|
@ -1,60 +0,0 @@
|
|||||||
import { Grid, GridItem } from '@chakra-ui/react';
|
|
||||||
import CurrentImage from './features/gallery/CurrentImage';
|
|
||||||
import LogViewer from './features/system/LogViewer';
|
|
||||||
import PromptInput from './features/sd/PromptInput';
|
|
||||||
import ProgressBar from './features/header/ProgressBar';
|
|
||||||
import { useEffect } from 'react';
|
|
||||||
import { useAppDispatch } from './app/hooks';
|
|
||||||
import { requestAllImages } from './app/socketio';
|
|
||||||
import ProcessButtons from './features/sd/ProcessButtons';
|
|
||||||
import ImageRoll from './features/gallery/ImageRoll';
|
|
||||||
import SiteHeader from './features/header/SiteHeader';
|
|
||||||
import OptionsAccordion from './features/sd/OptionsAccordion';
|
|
||||||
|
|
||||||
const App = () => {
|
|
||||||
const dispatch = useAppDispatch();
|
|
||||||
useEffect(() => {
|
|
||||||
dispatch(requestAllImages());
|
|
||||||
}, [dispatch]);
|
|
||||||
return (
|
|
||||||
<>
|
|
||||||
<Grid
|
|
||||||
width='100vw'
|
|
||||||
height='100vh'
|
|
||||||
templateAreas={`
|
|
||||||
"header header header header"
|
|
||||||
"progressBar progressBar progressBar progressBar"
|
|
||||||
"menu prompt processButtons imageRoll"
|
|
||||||
"menu currentImage currentImage imageRoll"`}
|
|
||||||
gridTemplateRows={'36px 10px 100px auto'}
|
|
||||||
gridTemplateColumns={'350px auto 100px 388px'}
|
|
||||||
gap={2}
|
|
||||||
>
|
|
||||||
<GridItem area={'header'} pt={1}>
|
|
||||||
<SiteHeader />
|
|
||||||
</GridItem>
|
|
||||||
<GridItem area={'progressBar'}>
|
|
||||||
<ProgressBar />
|
|
||||||
</GridItem>
|
|
||||||
<GridItem pl='2' area={'menu'} overflowY='scroll'>
|
|
||||||
<OptionsAccordion />
|
|
||||||
</GridItem>
|
|
||||||
<GridItem area={'prompt'}>
|
|
||||||
<PromptInput />
|
|
||||||
</GridItem>
|
|
||||||
<GridItem area={'processButtons'}>
|
|
||||||
<ProcessButtons />
|
|
||||||
</GridItem>
|
|
||||||
<GridItem area={'currentImage'}>
|
|
||||||
<CurrentImage />
|
|
||||||
</GridItem>
|
|
||||||
<GridItem pr='2' area={'imageRoll'} overflowY='scroll'>
|
|
||||||
<ImageRoll />
|
|
||||||
</GridItem>
|
|
||||||
</Grid>
|
|
||||||
<LogViewer />
|
|
||||||
</>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
export default App;
|
|
69
frontend/src/app/App.tsx
Normal file
69
frontend/src/app/App.tsx
Normal file
@ -0,0 +1,69 @@
|
|||||||
|
import { Grid, GridItem } from '@chakra-ui/react';
|
||||||
|
import { useEffect, useState } from 'react';
|
||||||
|
import CurrentImageDisplay from '../features/gallery/CurrentImageDisplay';
|
||||||
|
import ImageGallery from '../features/gallery/ImageGallery';
|
||||||
|
import ProgressBar from '../features/system/ProgressBar';
|
||||||
|
import SiteHeader from '../features/system/SiteHeader';
|
||||||
|
import OptionsAccordion from '../features/options/OptionsAccordion';
|
||||||
|
import ProcessButtons from '../features/options/ProcessButtons';
|
||||||
|
import PromptInput from '../features/options/PromptInput';
|
||||||
|
import LogViewer from '../features/system/LogViewer';
|
||||||
|
import Loading from '../Loading';
|
||||||
|
import { useAppDispatch } from './store';
|
||||||
|
import { requestAllImages, requestSystemConfig } from './socketio/actions';
|
||||||
|
|
||||||
|
const App = () => {
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const [isReady, setIsReady] = useState<boolean>(false);
|
||||||
|
|
||||||
|
// Load images from the gallery once
|
||||||
|
useEffect(() => {
|
||||||
|
dispatch(requestAllImages());
|
||||||
|
dispatch(requestSystemConfig());
|
||||||
|
setIsReady(true);
|
||||||
|
}, [dispatch]);
|
||||||
|
|
||||||
|
return isReady ? (
|
||||||
|
<>
|
||||||
|
<Grid
|
||||||
|
width="100vw"
|
||||||
|
height="100vh"
|
||||||
|
templateAreas={`
|
||||||
|
"header header header header"
|
||||||
|
"progressBar progressBar progressBar progressBar"
|
||||||
|
"menu prompt processButtons imageRoll"
|
||||||
|
"menu currentImage currentImage imageRoll"`}
|
||||||
|
gridTemplateRows={'36px 10px 100px auto'}
|
||||||
|
gridTemplateColumns={'350px auto 100px 388px'}
|
||||||
|
gap={2}
|
||||||
|
>
|
||||||
|
<GridItem area={'header'} pt={1}>
|
||||||
|
<SiteHeader />
|
||||||
|
</GridItem>
|
||||||
|
<GridItem area={'progressBar'}>
|
||||||
|
<ProgressBar />
|
||||||
|
</GridItem>
|
||||||
|
<GridItem pl="2" area={'menu'} overflowY="scroll">
|
||||||
|
<OptionsAccordion />
|
||||||
|
</GridItem>
|
||||||
|
<GridItem area={'prompt'}>
|
||||||
|
<PromptInput />
|
||||||
|
</GridItem>
|
||||||
|
<GridItem area={'processButtons'}>
|
||||||
|
<ProcessButtons />
|
||||||
|
</GridItem>
|
||||||
|
<GridItem area={'currentImage'}>
|
||||||
|
<CurrentImageDisplay />
|
||||||
|
</GridItem>
|
||||||
|
<GridItem pr="2" area={'imageRoll'} overflowY="scroll">
|
||||||
|
<ImageGallery />
|
||||||
|
</GridItem>
|
||||||
|
</Grid>
|
||||||
|
<LogViewer />
|
||||||
|
</>
|
||||||
|
) : (
|
||||||
|
<Loading />
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default App;
|
@ -2,52 +2,52 @@
|
|||||||
|
|
||||||
// Valid samplers
|
// Valid samplers
|
||||||
export const SAMPLERS: Array<string> = [
|
export const SAMPLERS: Array<string> = [
|
||||||
'ddim',
|
'ddim',
|
||||||
'plms',
|
'plms',
|
||||||
'k_lms',
|
'k_lms',
|
||||||
'k_dpm_2',
|
'k_dpm_2',
|
||||||
'k_dpm_2_a',
|
'k_dpm_2_a',
|
||||||
'k_euler',
|
'k_euler',
|
||||||
'k_euler_a',
|
'k_euler_a',
|
||||||
'k_heun',
|
'k_heun',
|
||||||
];
|
];
|
||||||
|
|
||||||
// Valid image widths
|
// Valid image widths
|
||||||
export const WIDTHS: Array<number> = [
|
export const WIDTHS: Array<number> = [
|
||||||
64, 128, 192, 256, 320, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960,
|
64, 128, 192, 256, 320, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960,
|
||||||
1024,
|
1024,
|
||||||
];
|
];
|
||||||
|
|
||||||
// Valid image heights
|
// Valid image heights
|
||||||
export const HEIGHTS: Array<number> = [
|
export const HEIGHTS: Array<number> = [
|
||||||
64, 128, 192, 256, 320, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960,
|
64, 128, 192, 256, 320, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960,
|
||||||
1024,
|
1024,
|
||||||
];
|
];
|
||||||
|
|
||||||
// Valid upscaling levels
|
// Valid upscaling levels
|
||||||
export const UPSCALING_LEVELS: Array<{ key: string; value: number }> = [
|
export const UPSCALING_LEVELS: Array<{ key: string; value: number }> = [
|
||||||
{ key: '2x', value: 2 },
|
{ key: '2x', value: 2 },
|
||||||
{ key: '4x', value: 4 },
|
{ key: '4x', value: 4 },
|
||||||
];
|
];
|
||||||
|
|
||||||
// Internal to human-readable parameters
|
// Internal to human-readable parameters
|
||||||
export const PARAMETERS: { [key: string]: string } = {
|
export const PARAMETERS: { [key: string]: string } = {
|
||||||
prompt: 'Prompt',
|
prompt: 'Prompt',
|
||||||
iterations: 'Iterations',
|
iterations: 'Iterations',
|
||||||
steps: 'Steps',
|
steps: 'Steps',
|
||||||
cfgScale: 'CFG Scale',
|
cfgScale: 'CFG Scale',
|
||||||
height: 'Height',
|
height: 'Height',
|
||||||
width: 'Width',
|
width: 'Width',
|
||||||
sampler: 'Sampler',
|
sampler: 'Sampler',
|
||||||
seed: 'Seed',
|
seed: 'Seed',
|
||||||
img2imgStrength: 'img2img Strength',
|
img2imgStrength: 'img2img Strength',
|
||||||
gfpganStrength: 'GFPGAN Strength',
|
gfpganStrength: 'GFPGAN Strength',
|
||||||
upscalingLevel: 'Upscaling Level',
|
upscalingLevel: 'Upscaling Level',
|
||||||
upscalingStrength: 'Upscaling Strength',
|
upscalingStrength: 'Upscaling Strength',
|
||||||
initialImagePath: 'Initial Image',
|
initialImagePath: 'Initial Image',
|
||||||
maskPath: 'Initial Image Mask',
|
maskPath: 'Initial Image Mask',
|
||||||
shouldFitToWidthHeight: 'Fit Initial Image',
|
shouldFitToWidthHeight: 'Fit Initial Image',
|
||||||
seamless: 'Seamless Tiling',
|
seamless: 'Seamless Tiling',
|
||||||
};
|
};
|
||||||
|
|
||||||
export const NUMPY_RAND_MIN = 0;
|
export const NUMPY_RAND_MIN = 0;
|
||||||
|
@ -1,7 +0,0 @@
|
|||||||
import { useDispatch, useSelector } from 'react-redux';
|
|
||||||
import type { TypedUseSelectorHook } from 'react-redux';
|
|
||||||
import type { RootState, AppDispatch } from './store';
|
|
||||||
|
|
||||||
// Use throughout your app instead of plain `useDispatch` and `useSelector`
|
|
||||||
export const useAppDispatch: () => AppDispatch = useDispatch;
|
|
||||||
export const useAppSelector: TypedUseSelectorHook<RootState> = useSelector;
|
|
170
frontend/src/app/invokeai.d.ts
vendored
Normal file
170
frontend/src/app/invokeai.d.ts
vendored
Normal file
@ -0,0 +1,170 @@
|
|||||||
|
/**
|
||||||
|
* Types for images, the things they are made of, and the things
|
||||||
|
* they make up.
|
||||||
|
*
|
||||||
|
* Generated images are txt2img and img2img images. They may have
|
||||||
|
* had additional postprocessing done on them when they were first
|
||||||
|
* generated.
|
||||||
|
*
|
||||||
|
* Postprocessed images are images which were not generated here
|
||||||
|
* but only postprocessed by the app. They only get postprocessing
|
||||||
|
* metadata and have a different image type, e.g. 'esrgan' or
|
||||||
|
* 'gfpgan'.
|
||||||
|
*/
|
||||||
|
|
||||||
|
/**
|
||||||
|
* TODO:
|
||||||
|
* Once an image has been generated, if it is postprocessed again,
|
||||||
|
* additional postprocessing steps are added to its postprocessing
|
||||||
|
* array.
|
||||||
|
*
|
||||||
|
* TODO: Better documentation of types.
|
||||||
|
*/
|
||||||
|
|
||||||
|
export declare type PromptItem = {
|
||||||
|
prompt: string;
|
||||||
|
weight: number;
|
||||||
|
};
|
||||||
|
|
||||||
|
export declare type Prompt = Array<PromptItem>;
|
||||||
|
|
||||||
|
export declare type SeedWeightPair = {
|
||||||
|
seed: number;
|
||||||
|
weight: number;
|
||||||
|
};
|
||||||
|
|
||||||
|
export declare type SeedWeights = Array<SeedWeightPair>;
|
||||||
|
|
||||||
|
// All generated images contain these metadata.
|
||||||
|
export declare type CommonGeneratedImageMetadata = {
|
||||||
|
postprocessing: null | Array<ESRGANMetadata | GFPGANMetadata>;
|
||||||
|
sampler:
|
||||||
|
| 'ddim'
|
||||||
|
| 'k_dpm_2_a'
|
||||||
|
| 'k_dpm_2'
|
||||||
|
| 'k_euler_a'
|
||||||
|
| 'k_euler'
|
||||||
|
| 'k_heun'
|
||||||
|
| 'k_lms'
|
||||||
|
| 'plms';
|
||||||
|
prompt: Prompt;
|
||||||
|
seed: number;
|
||||||
|
variations: SeedWeights;
|
||||||
|
steps: number;
|
||||||
|
cfg_scale: number;
|
||||||
|
width: number;
|
||||||
|
height: number;
|
||||||
|
seamless: boolean;
|
||||||
|
extra: null | Record<string, never>; // Pending development of RFC #266
|
||||||
|
};
|
||||||
|
|
||||||
|
// txt2img and img2img images have some unique attributes.
|
||||||
|
export declare type Txt2ImgMetadata = GeneratedImageMetadata & {
|
||||||
|
type: 'txt2img';
|
||||||
|
};
|
||||||
|
|
||||||
|
export declare type Img2ImgMetadata = GeneratedImageMetadata & {
|
||||||
|
type: 'img2img';
|
||||||
|
orig_hash: string;
|
||||||
|
strength: number;
|
||||||
|
fit: boolean;
|
||||||
|
init_image_path: string;
|
||||||
|
mask_image_path?: string;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Superset of generated image metadata types.
|
||||||
|
export declare type GeneratedImageMetadata = Txt2ImgMetadata | Img2ImgMetadata;
|
||||||
|
|
||||||
|
// All post processed images contain these metadata.
|
||||||
|
export declare type CommonPostProcessedImageMetadata = {
|
||||||
|
orig_path: string;
|
||||||
|
orig_hash: string;
|
||||||
|
};
|
||||||
|
|
||||||
|
// esrgan and gfpgan images have some unique attributes.
|
||||||
|
export declare type ESRGANMetadata = CommonPostProcessedImageMetadata & {
|
||||||
|
type: 'esrgan';
|
||||||
|
scale: 2 | 4;
|
||||||
|
strength: number;
|
||||||
|
};
|
||||||
|
|
||||||
|
export declare type GFPGANMetadata = CommonPostProcessedImageMetadata & {
|
||||||
|
type: 'gfpgan';
|
||||||
|
strength: number;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Superset of all postprocessed image metadata types..
|
||||||
|
export declare type PostProcessedImageMetadata =
|
||||||
|
| ESRGANMetadata
|
||||||
|
| GFPGANMetadata;
|
||||||
|
|
||||||
|
// Metadata includes the system config and image metadata.
|
||||||
|
export declare type Metadata = SystemConfig & {
|
||||||
|
image: GeneratedImageMetadata | PostProcessedImageMetadata;
|
||||||
|
};
|
||||||
|
|
||||||
|
// An Image has a UUID, url (path?) and Metadata.
|
||||||
|
export declare type Image = {
|
||||||
|
uuid: string;
|
||||||
|
url: string;
|
||||||
|
metadata: Metadata;
|
||||||
|
};
|
||||||
|
|
||||||
|
// GalleryImages is an array of Image.
|
||||||
|
export declare type GalleryImages = {
|
||||||
|
images: Array<Image>;
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Types related to the system status.
|
||||||
|
*/
|
||||||
|
|
||||||
|
// This represents the processing status of the backend.
|
||||||
|
export declare type SystemStatus = {
|
||||||
|
isProcessing: boolean;
|
||||||
|
currentStep: number;
|
||||||
|
totalSteps: number;
|
||||||
|
currentIteration: number;
|
||||||
|
totalIterations: number;
|
||||||
|
currentStatus: string;
|
||||||
|
currentStatusHasSteps: boolean;
|
||||||
|
};
|
||||||
|
|
||||||
|
export declare type SystemConfig = {
|
||||||
|
model: string;
|
||||||
|
model_id: string;
|
||||||
|
model_hash: string;
|
||||||
|
app_id: string;
|
||||||
|
app_version: string;
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* These types type data received from the server via socketio.
|
||||||
|
*/
|
||||||
|
|
||||||
|
export declare type SystemStatusResponse = SystemStatus;
|
||||||
|
|
||||||
|
export declare type SystemConfigResponse = SystemConfig;
|
||||||
|
|
||||||
|
export declare type ImageResultResponse = {
|
||||||
|
url: string;
|
||||||
|
metadata: Metadata;
|
||||||
|
};
|
||||||
|
|
||||||
|
export declare type ErrorResponse = {
|
||||||
|
message: string;
|
||||||
|
additionalData?: string;
|
||||||
|
};
|
||||||
|
|
||||||
|
export declare type GalleryImagesResponse = {
|
||||||
|
images: Array<{ url: string; metadata: Metadata }>;
|
||||||
|
};
|
||||||
|
|
||||||
|
export declare type ImageUrlAndUuidResponse = {
|
||||||
|
uuid: string;
|
||||||
|
url: string;
|
||||||
|
};
|
||||||
|
|
||||||
|
export declare type ImageUrlResponse = {
|
||||||
|
url: string;
|
||||||
|
};
|
@ -1,182 +0,0 @@
|
|||||||
import { SDState } from '../features/sd/sdSlice';
|
|
||||||
import randomInt from '../features/sd/util/randomInt';
|
|
||||||
import {
|
|
||||||
seedWeightsToString,
|
|
||||||
stringToSeedWeights,
|
|
||||||
} from '../features/sd/util/seedWeightPairs';
|
|
||||||
import { SystemState } from '../features/system/systemSlice';
|
|
||||||
import { NUMPY_RAND_MAX, NUMPY_RAND_MIN } from './constants';
|
|
||||||
|
|
||||||
/*
|
|
||||||
These functions translate frontend state into parameters
|
|
||||||
suitable for consumption by the backend, and vice-versa.
|
|
||||||
*/
|
|
||||||
|
|
||||||
export const frontendToBackendParameters = (
|
|
||||||
sdState: SDState,
|
|
||||||
systemState: SystemState
|
|
||||||
): { [key: string]: any } => {
|
|
||||||
const {
|
|
||||||
prompt,
|
|
||||||
iterations,
|
|
||||||
steps,
|
|
||||||
cfgScale,
|
|
||||||
height,
|
|
||||||
width,
|
|
||||||
sampler,
|
|
||||||
seed,
|
|
||||||
seamless,
|
|
||||||
shouldUseInitImage,
|
|
||||||
img2imgStrength,
|
|
||||||
initialImagePath,
|
|
||||||
maskPath,
|
|
||||||
shouldFitToWidthHeight,
|
|
||||||
shouldGenerateVariations,
|
|
||||||
variantAmount,
|
|
||||||
seedWeights,
|
|
||||||
shouldRunESRGAN,
|
|
||||||
upscalingLevel,
|
|
||||||
upscalingStrength,
|
|
||||||
shouldRunGFPGAN,
|
|
||||||
gfpganStrength,
|
|
||||||
shouldRandomizeSeed,
|
|
||||||
} = sdState;
|
|
||||||
|
|
||||||
const { shouldDisplayInProgress } = systemState;
|
|
||||||
|
|
||||||
const generationParameters: { [k: string]: any } = {
|
|
||||||
prompt,
|
|
||||||
iterations,
|
|
||||||
steps,
|
|
||||||
cfg_scale: cfgScale,
|
|
||||||
height,
|
|
||||||
width,
|
|
||||||
sampler_name: sampler,
|
|
||||||
seed,
|
|
||||||
seamless,
|
|
||||||
progress_images: shouldDisplayInProgress,
|
|
||||||
};
|
|
||||||
|
|
||||||
generationParameters.seed = shouldRandomizeSeed
|
|
||||||
? randomInt(NUMPY_RAND_MIN, NUMPY_RAND_MAX)
|
|
||||||
: seed;
|
|
||||||
|
|
||||||
if (shouldUseInitImage) {
|
|
||||||
generationParameters.init_img = initialImagePath;
|
|
||||||
generationParameters.strength = img2imgStrength;
|
|
||||||
generationParameters.fit = shouldFitToWidthHeight;
|
|
||||||
if (maskPath) {
|
|
||||||
generationParameters.init_mask = maskPath;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (shouldGenerateVariations) {
|
|
||||||
generationParameters.variation_amount = variantAmount;
|
|
||||||
if (seedWeights) {
|
|
||||||
generationParameters.with_variations =
|
|
||||||
stringToSeedWeights(seedWeights);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
generationParameters.variation_amount = 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
let esrganParameters: false | { [k: string]: any } = false;
|
|
||||||
let gfpganParameters: false | { [k: string]: any } = false;
|
|
||||||
|
|
||||||
if (shouldRunESRGAN) {
|
|
||||||
esrganParameters = {
|
|
||||||
level: upscalingLevel,
|
|
||||||
strength: upscalingStrength,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
if (shouldRunGFPGAN) {
|
|
||||||
gfpganParameters = {
|
|
||||||
strength: gfpganStrength,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
return {
|
|
||||||
generationParameters,
|
|
||||||
esrganParameters,
|
|
||||||
gfpganParameters,
|
|
||||||
};
|
|
||||||
};
|
|
||||||
|
|
||||||
export const backendToFrontendParameters = (parameters: {
|
|
||||||
[key: string]: any;
|
|
||||||
}) => {
|
|
||||||
const {
|
|
||||||
prompt,
|
|
||||||
iterations,
|
|
||||||
steps,
|
|
||||||
cfg_scale,
|
|
||||||
height,
|
|
||||||
width,
|
|
||||||
sampler_name,
|
|
||||||
seed,
|
|
||||||
seamless,
|
|
||||||
progress_images,
|
|
||||||
variation_amount,
|
|
||||||
with_variations,
|
|
||||||
gfpgan_strength,
|
|
||||||
upscale,
|
|
||||||
init_img,
|
|
||||||
init_mask,
|
|
||||||
strength,
|
|
||||||
} = parameters;
|
|
||||||
|
|
||||||
const sd: { [key: string]: any } = {
|
|
||||||
shouldDisplayInProgress: progress_images,
|
|
||||||
// init
|
|
||||||
shouldGenerateVariations: false,
|
|
||||||
shouldRunESRGAN: false,
|
|
||||||
shouldRunGFPGAN: false,
|
|
||||||
initialImagePath: '',
|
|
||||||
maskPath: '',
|
|
||||||
};
|
|
||||||
|
|
||||||
if (variation_amount > 0) {
|
|
||||||
sd.shouldGenerateVariations = true;
|
|
||||||
sd.variantAmount = variation_amount;
|
|
||||||
if (with_variations) {
|
|
||||||
sd.seedWeights = seedWeightsToString(with_variations);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (gfpgan_strength > 0) {
|
|
||||||
sd.shouldRunGFPGAN = true;
|
|
||||||
sd.gfpganStrength = gfpgan_strength;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (upscale) {
|
|
||||||
sd.shouldRunESRGAN = true;
|
|
||||||
sd.upscalingLevel = upscale[0];
|
|
||||||
sd.upscalingStrength = upscale[1];
|
|
||||||
}
|
|
||||||
|
|
||||||
if (init_img) {
|
|
||||||
sd.shouldUseInitImage = true
|
|
||||||
sd.initialImagePath = init_img;
|
|
||||||
sd.strength = strength;
|
|
||||||
if (init_mask) {
|
|
||||||
sd.maskPath = init_mask;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// if we had a prompt, add all the metadata, but if we don't have a prompt,
|
|
||||||
// we must have only done ESRGAN or GFPGAN so do not add that metadata
|
|
||||||
if (prompt) {
|
|
||||||
sd.prompt = prompt;
|
|
||||||
sd.iterations = iterations;
|
|
||||||
sd.steps = steps;
|
|
||||||
sd.cfgScale = cfg_scale;
|
|
||||||
sd.height = height;
|
|
||||||
sd.width = width;
|
|
||||||
sd.sampler = sampler_name;
|
|
||||||
sd.seed = seed;
|
|
||||||
sd.seamless = seamless;
|
|
||||||
}
|
|
||||||
|
|
||||||
return sd;
|
|
||||||
};
|
|
@ -1,393 +0,0 @@
|
|||||||
import { createAction, Middleware } from '@reduxjs/toolkit';
|
|
||||||
import { io } from 'socket.io-client';
|
|
||||||
import {
|
|
||||||
addImage,
|
|
||||||
clearIntermediateImage,
|
|
||||||
removeImage,
|
|
||||||
SDImage,
|
|
||||||
SDMetadata,
|
|
||||||
setGalleryImages,
|
|
||||||
setIntermediateImage,
|
|
||||||
} from '../features/gallery/gallerySlice';
|
|
||||||
import {
|
|
||||||
addLogEntry,
|
|
||||||
setCurrentStep,
|
|
||||||
setIsConnected,
|
|
||||||
setIsProcessing,
|
|
||||||
} from '../features/system/systemSlice';
|
|
||||||
import { v4 as uuidv4 } from 'uuid';
|
|
||||||
import { setInitialImagePath, setMaskPath } from '../features/sd/sdSlice';
|
|
||||||
import {
|
|
||||||
backendToFrontendParameters,
|
|
||||||
frontendToBackendParameters,
|
|
||||||
} from './parameterTranslation';
|
|
||||||
|
|
||||||
export interface SocketIOResponse {
|
|
||||||
status: 'OK' | 'ERROR';
|
|
||||||
message?: string;
|
|
||||||
data?: any;
|
|
||||||
}
|
|
||||||
|
|
||||||
export const socketioMiddleware = () => {
|
|
||||||
const { hostname, port } = new URL(window.location.href);
|
|
||||||
|
|
||||||
const socketio = io(`http://${hostname}:9090`);
|
|
||||||
|
|
||||||
let areListenersSet = false;
|
|
||||||
|
|
||||||
const middleware: Middleware = (store) => (next) => (action) => {
|
|
||||||
const { dispatch, getState } = store;
|
|
||||||
if (!areListenersSet) {
|
|
||||||
// CONNECT
|
|
||||||
socketio.on('connect', () => {
|
|
||||||
try {
|
|
||||||
dispatch(setIsConnected(true));
|
|
||||||
} catch (e) {
|
|
||||||
console.error(e);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
// DISCONNECT
|
|
||||||
socketio.on('disconnect', () => {
|
|
||||||
try {
|
|
||||||
dispatch(setIsConnected(false));
|
|
||||||
dispatch(setIsProcessing(false));
|
|
||||||
dispatch(addLogEntry(`Disconnected from server`));
|
|
||||||
} catch (e) {
|
|
||||||
console.error(e);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
// PROCESSING RESULT
|
|
||||||
socketio.on(
|
|
||||||
'result',
|
|
||||||
(data: {
|
|
||||||
url: string;
|
|
||||||
type: 'generation' | 'esrgan' | 'gfpgan';
|
|
||||||
uuid?: string;
|
|
||||||
metadata: { [key: string]: any };
|
|
||||||
}) => {
|
|
||||||
try {
|
|
||||||
const newUuid = uuidv4();
|
|
||||||
const { type, url, uuid, metadata } = data;
|
|
||||||
switch (type) {
|
|
||||||
case 'generation': {
|
|
||||||
const translatedMetadata =
|
|
||||||
backendToFrontendParameters(metadata);
|
|
||||||
dispatch(
|
|
||||||
addImage({
|
|
||||||
uuid: newUuid,
|
|
||||||
url,
|
|
||||||
metadata: translatedMetadata,
|
|
||||||
})
|
|
||||||
);
|
|
||||||
dispatch(
|
|
||||||
addLogEntry(`Image generated: ${url}`)
|
|
||||||
);
|
|
||||||
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
case 'esrgan': {
|
|
||||||
const originalImage =
|
|
||||||
getState().gallery.images.find(
|
|
||||||
(i: SDImage) => i.uuid === uuid
|
|
||||||
);
|
|
||||||
const newMetadata = {
|
|
||||||
...originalImage.metadata,
|
|
||||||
};
|
|
||||||
newMetadata.shouldRunESRGAN = true;
|
|
||||||
newMetadata.upscalingLevel =
|
|
||||||
metadata.upscale[0];
|
|
||||||
newMetadata.upscalingStrength =
|
|
||||||
metadata.upscale[1];
|
|
||||||
dispatch(
|
|
||||||
addImage({
|
|
||||||
uuid: newUuid,
|
|
||||||
url,
|
|
||||||
metadata: newMetadata,
|
|
||||||
})
|
|
||||||
);
|
|
||||||
dispatch(
|
|
||||||
addLogEntry(`ESRGAN upscaled: ${url}`)
|
|
||||||
);
|
|
||||||
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
case 'gfpgan': {
|
|
||||||
const originalImage =
|
|
||||||
getState().gallery.images.find(
|
|
||||||
(i: SDImage) => i.uuid === uuid
|
|
||||||
);
|
|
||||||
const newMetadata = {
|
|
||||||
...originalImage.metadata,
|
|
||||||
};
|
|
||||||
newMetadata.shouldRunGFPGAN = true;
|
|
||||||
newMetadata.gfpganStrength =
|
|
||||||
metadata.gfpgan_strength;
|
|
||||||
dispatch(
|
|
||||||
addImage({
|
|
||||||
uuid: newUuid,
|
|
||||||
url,
|
|
||||||
metadata: newMetadata,
|
|
||||||
})
|
|
||||||
);
|
|
||||||
dispatch(
|
|
||||||
addLogEntry(`GFPGAN fixed faces: ${url}`)
|
|
||||||
);
|
|
||||||
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
dispatch(setIsProcessing(false));
|
|
||||||
} catch (e) {
|
|
||||||
console.error(e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
);
|
|
||||||
|
|
||||||
// PROGRESS UPDATE
|
|
||||||
socketio.on('progress', (data: { step: number }) => {
|
|
||||||
try {
|
|
||||||
dispatch(setIsProcessing(true));
|
|
||||||
dispatch(setCurrentStep(data.step));
|
|
||||||
} catch (e) {
|
|
||||||
console.error(e);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
// INTERMEDIATE IMAGE
|
|
||||||
socketio.on(
|
|
||||||
'intermediateResult',
|
|
||||||
(data: { url: string; metadata: SDMetadata }) => {
|
|
||||||
try {
|
|
||||||
const uuid = uuidv4();
|
|
||||||
const { url, metadata } = data;
|
|
||||||
dispatch(
|
|
||||||
setIntermediateImage({
|
|
||||||
uuid,
|
|
||||||
url,
|
|
||||||
metadata,
|
|
||||||
})
|
|
||||||
);
|
|
||||||
dispatch(
|
|
||||||
addLogEntry(`Intermediate image generated: ${url}`)
|
|
||||||
);
|
|
||||||
} catch (e) {
|
|
||||||
console.error(e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
);
|
|
||||||
|
|
||||||
// ERROR FROM BACKEND
|
|
||||||
socketio.on('error', (message) => {
|
|
||||||
try {
|
|
||||||
dispatch(addLogEntry(`Server error: ${message}`));
|
|
||||||
dispatch(setIsProcessing(false));
|
|
||||||
dispatch(clearIntermediateImage());
|
|
||||||
} catch (e) {
|
|
||||||
console.error(e);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
areListenersSet = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
// HANDLE ACTIONS
|
|
||||||
|
|
||||||
switch (action.type) {
|
|
||||||
// GENERATE IMAGE
|
|
||||||
case 'socketio/generateImage': {
|
|
||||||
dispatch(setIsProcessing(true));
|
|
||||||
dispatch(setCurrentStep(-1));
|
|
||||||
|
|
||||||
const {
|
|
||||||
generationParameters,
|
|
||||||
esrganParameters,
|
|
||||||
gfpganParameters,
|
|
||||||
} = frontendToBackendParameters(
|
|
||||||
getState().sd,
|
|
||||||
getState().system
|
|
||||||
);
|
|
||||||
|
|
||||||
socketio.emit(
|
|
||||||
'generateImage',
|
|
||||||
generationParameters,
|
|
||||||
esrganParameters,
|
|
||||||
gfpganParameters
|
|
||||||
);
|
|
||||||
|
|
||||||
dispatch(
|
|
||||||
addLogEntry(
|
|
||||||
`Image generation requested: ${JSON.stringify({
|
|
||||||
...generationParameters,
|
|
||||||
...esrganParameters,
|
|
||||||
...gfpganParameters,
|
|
||||||
})}`
|
|
||||||
)
|
|
||||||
);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
// RUN ESRGAN (UPSCALING)
|
|
||||||
case 'socketio/runESRGAN': {
|
|
||||||
const imageToProcess = action.payload;
|
|
||||||
dispatch(setIsProcessing(true));
|
|
||||||
dispatch(setCurrentStep(-1));
|
|
||||||
const { upscalingLevel, upscalingStrength } = getState().sd;
|
|
||||||
const esrganParameters = {
|
|
||||||
upscale: [upscalingLevel, upscalingStrength],
|
|
||||||
};
|
|
||||||
socketio.emit('runESRGAN', imageToProcess, esrganParameters);
|
|
||||||
dispatch(
|
|
||||||
addLogEntry(
|
|
||||||
`ESRGAN upscale requested: ${JSON.stringify({
|
|
||||||
file: imageToProcess.url,
|
|
||||||
...esrganParameters,
|
|
||||||
})}`
|
|
||||||
)
|
|
||||||
);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
// RUN GFPGAN (FIX FACES)
|
|
||||||
case 'socketio/runGFPGAN': {
|
|
||||||
const imageToProcess = action.payload;
|
|
||||||
dispatch(setIsProcessing(true));
|
|
||||||
dispatch(setCurrentStep(-1));
|
|
||||||
const { gfpganStrength } = getState().sd;
|
|
||||||
|
|
||||||
const gfpganParameters = {
|
|
||||||
gfpgan_strength: gfpganStrength,
|
|
||||||
};
|
|
||||||
socketio.emit('runGFPGAN', imageToProcess, gfpganParameters);
|
|
||||||
dispatch(
|
|
||||||
addLogEntry(
|
|
||||||
`GFPGAN fix faces requested: ${JSON.stringify({
|
|
||||||
file: imageToProcess.url,
|
|
||||||
...gfpganParameters,
|
|
||||||
})}`
|
|
||||||
)
|
|
||||||
);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
// DELETE IMAGE
|
|
||||||
case 'socketio/deleteImage': {
|
|
||||||
const imageToDelete = action.payload;
|
|
||||||
const { url } = imageToDelete;
|
|
||||||
socketio.emit(
|
|
||||||
'deleteImage',
|
|
||||||
url,
|
|
||||||
(response: SocketIOResponse) => {
|
|
||||||
if (response.status === 'OK') {
|
|
||||||
dispatch(removeImage(imageToDelete));
|
|
||||||
dispatch(addLogEntry(`Image deleted: ${url}`));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
// GET ALL IMAGES FOR GALLERY
|
|
||||||
case 'socketio/requestAllImages': {
|
|
||||||
socketio.emit(
|
|
||||||
'requestAllImages',
|
|
||||||
(response: SocketIOResponse) => {
|
|
||||||
dispatch(setGalleryImages(response.data));
|
|
||||||
dispatch(
|
|
||||||
addLogEntry(`Loaded ${response.data.length} images`)
|
|
||||||
);
|
|
||||||
}
|
|
||||||
);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
// CANCEL PROCESSING
|
|
||||||
case 'socketio/cancelProcessing': {
|
|
||||||
socketio.emit('cancel', (response: SocketIOResponse) => {
|
|
||||||
const { intermediateImage } = getState().gallery;
|
|
||||||
if (response.status === 'OK') {
|
|
||||||
dispatch(setIsProcessing(false));
|
|
||||||
if (intermediateImage) {
|
|
||||||
dispatch(addImage(intermediateImage));
|
|
||||||
dispatch(
|
|
||||||
addLogEntry(
|
|
||||||
`Intermediate image saved: ${intermediateImage.url}`
|
|
||||||
)
|
|
||||||
);
|
|
||||||
|
|
||||||
dispatch(clearIntermediateImage());
|
|
||||||
}
|
|
||||||
dispatch(addLogEntry(`Processing canceled`));
|
|
||||||
}
|
|
||||||
});
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
// UPLOAD INITIAL IMAGE
|
|
||||||
case 'socketio/uploadInitialImage': {
|
|
||||||
const file = action.payload;
|
|
||||||
|
|
||||||
socketio.emit(
|
|
||||||
'uploadInitialImage',
|
|
||||||
file,
|
|
||||||
file.name,
|
|
||||||
(response: SocketIOResponse) => {
|
|
||||||
if (response.status === 'OK') {
|
|
||||||
dispatch(setInitialImagePath(response.data));
|
|
||||||
dispatch(
|
|
||||||
addLogEntry(
|
|
||||||
`Initial image uploaded: ${response.data}`
|
|
||||||
)
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
// UPLOAD MASK IMAGE
|
|
||||||
case 'socketio/uploadMaskImage': {
|
|
||||||
const file = action.payload;
|
|
||||||
|
|
||||||
socketio.emit(
|
|
||||||
'uploadMaskImage',
|
|
||||||
file,
|
|
||||||
file.name,
|
|
||||||
(response: SocketIOResponse) => {
|
|
||||||
if (response.status === 'OK') {
|
|
||||||
dispatch(setMaskPath(response.data));
|
|
||||||
dispatch(
|
|
||||||
addLogEntry(
|
|
||||||
`Mask image uploaded: ${response.data}`
|
|
||||||
)
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
next(action);
|
|
||||||
};
|
|
||||||
|
|
||||||
return middleware;
|
|
||||||
};
|
|
||||||
|
|
||||||
// Actions to be used by app
|
|
||||||
|
|
||||||
export const generateImage = createAction<undefined>('socketio/generateImage');
|
|
||||||
export const runESRGAN = createAction<SDImage>('socketio/runESRGAN');
|
|
||||||
export const runGFPGAN = createAction<SDImage>('socketio/runGFPGAN');
|
|
||||||
export const deleteImage = createAction<SDImage>('socketio/deleteImage');
|
|
||||||
export const requestAllImages = createAction<undefined>(
|
|
||||||
'socketio/requestAllImages'
|
|
||||||
);
|
|
||||||
export const cancelProcessing = createAction<undefined>(
|
|
||||||
'socketio/cancelProcessing'
|
|
||||||
);
|
|
||||||
export const uploadInitialImage = createAction<File>(
|
|
||||||
'socketio/uploadInitialImage'
|
|
||||||
);
|
|
||||||
export const uploadMaskImage = createAction<File>('socketio/uploadMaskImage');
|
|
26
frontend/src/app/socketio/actions.ts
Normal file
26
frontend/src/app/socketio/actions.ts
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
import { createAction } from '@reduxjs/toolkit';
|
||||||
|
import * as InvokeAI from '../invokeai';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* We can't use redux-toolkit's createSlice() to make these actions,
|
||||||
|
* because they have no associated reducer. They only exist to dispatch
|
||||||
|
* requests to the server via socketio. These actions will be handled
|
||||||
|
* by the middleware.
|
||||||
|
*/
|
||||||
|
|
||||||
|
export const generateImage = createAction<undefined>('socketio/generateImage');
|
||||||
|
export const runESRGAN = createAction<InvokeAI.Image>('socketio/runESRGAN');
|
||||||
|
export const runGFPGAN = createAction<InvokeAI.Image>('socketio/runGFPGAN');
|
||||||
|
export const deleteImage = createAction<InvokeAI.Image>('socketio/deleteImage');
|
||||||
|
export const requestAllImages = createAction<undefined>(
|
||||||
|
'socketio/requestAllImages'
|
||||||
|
);
|
||||||
|
export const cancelProcessing = createAction<undefined>(
|
||||||
|
'socketio/cancelProcessing'
|
||||||
|
);
|
||||||
|
export const uploadInitialImage = createAction<File>(
|
||||||
|
'socketio/uploadInitialImage'
|
||||||
|
);
|
||||||
|
export const uploadMaskImage = createAction<File>('socketio/uploadMaskImage');
|
||||||
|
|
||||||
|
export const requestSystemConfig = createAction<undefined>('socketio/requestSystemConfig');
|
104
frontend/src/app/socketio/emitters.ts
Normal file
104
frontend/src/app/socketio/emitters.ts
Normal file
@ -0,0 +1,104 @@
|
|||||||
|
import { AnyAction, Dispatch, MiddlewareAPI } from '@reduxjs/toolkit';
|
||||||
|
import dateFormat from 'dateformat';
|
||||||
|
import { Socket } from 'socket.io-client';
|
||||||
|
import { frontendToBackendParameters } from '../../common/util/parameterTranslation';
|
||||||
|
import {
|
||||||
|
addLogEntry,
|
||||||
|
setIsProcessing,
|
||||||
|
} from '../../features/system/systemSlice';
|
||||||
|
import * as InvokeAI from '../invokeai';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns an object containing all functions which use `socketio.emit()`.
|
||||||
|
* i.e. those which make server requests.
|
||||||
|
*/
|
||||||
|
const makeSocketIOEmitters = (
|
||||||
|
store: MiddlewareAPI<Dispatch<AnyAction>, any>,
|
||||||
|
socketio: Socket
|
||||||
|
) => {
|
||||||
|
// We need to dispatch actions to redux and get pieces of state from the store.
|
||||||
|
const { dispatch, getState } = store;
|
||||||
|
|
||||||
|
return {
|
||||||
|
emitGenerateImage: () => {
|
||||||
|
dispatch(setIsProcessing(true));
|
||||||
|
|
||||||
|
const { generationParameters, esrganParameters, gfpganParameters } =
|
||||||
|
frontendToBackendParameters(getState().options, getState().system);
|
||||||
|
|
||||||
|
socketio.emit(
|
||||||
|
'generateImage',
|
||||||
|
generationParameters,
|
||||||
|
esrganParameters,
|
||||||
|
gfpganParameters
|
||||||
|
);
|
||||||
|
|
||||||
|
dispatch(
|
||||||
|
addLogEntry({
|
||||||
|
timestamp: dateFormat(new Date(), 'isoDateTime'),
|
||||||
|
message: `Image generation requested: ${JSON.stringify({
|
||||||
|
...generationParameters,
|
||||||
|
...esrganParameters,
|
||||||
|
...gfpganParameters,
|
||||||
|
})}`,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
},
|
||||||
|
emitRunESRGAN: (imageToProcess: InvokeAI.Image) => {
|
||||||
|
dispatch(setIsProcessing(true));
|
||||||
|
const { upscalingLevel, upscalingStrength } = getState().options;
|
||||||
|
const esrganParameters = {
|
||||||
|
upscale: [upscalingLevel, upscalingStrength],
|
||||||
|
};
|
||||||
|
socketio.emit('runESRGAN', imageToProcess, esrganParameters);
|
||||||
|
dispatch(
|
||||||
|
addLogEntry({
|
||||||
|
timestamp: dateFormat(new Date(), 'isoDateTime'),
|
||||||
|
message: `ESRGAN upscale requested: ${JSON.stringify({
|
||||||
|
file: imageToProcess.url,
|
||||||
|
...esrganParameters,
|
||||||
|
})}`,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
},
|
||||||
|
emitRunGFPGAN: (imageToProcess: InvokeAI.Image) => {
|
||||||
|
dispatch(setIsProcessing(true));
|
||||||
|
const { gfpganStrength } = getState().options;
|
||||||
|
|
||||||
|
const gfpganParameters = {
|
||||||
|
gfpgan_strength: gfpganStrength,
|
||||||
|
};
|
||||||
|
socketio.emit('runGFPGAN', imageToProcess, gfpganParameters);
|
||||||
|
dispatch(
|
||||||
|
addLogEntry({
|
||||||
|
timestamp: dateFormat(new Date(), 'isoDateTime'),
|
||||||
|
message: `GFPGAN fix faces requested: ${JSON.stringify({
|
||||||
|
file: imageToProcess.url,
|
||||||
|
...gfpganParameters,
|
||||||
|
})}`,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
},
|
||||||
|
emitDeleteImage: (imageToDelete: InvokeAI.Image) => {
|
||||||
|
const { url, uuid } = imageToDelete;
|
||||||
|
socketio.emit('deleteImage', url, uuid);
|
||||||
|
},
|
||||||
|
emitRequestAllImages: () => {
|
||||||
|
socketio.emit('requestAllImages');
|
||||||
|
},
|
||||||
|
emitCancelProcessing: () => {
|
||||||
|
socketio.emit('cancel');
|
||||||
|
},
|
||||||
|
emitUploadInitialImage: (file: File) => {
|
||||||
|
socketio.emit('uploadInitialImage', file, file.name);
|
||||||
|
},
|
||||||
|
emitUploadMaskImage: (file: File) => {
|
||||||
|
socketio.emit('uploadMaskImage', file, file.name);
|
||||||
|
},
|
||||||
|
emitRequestSystemConfig: () => {
|
||||||
|
socketio.emit('requestSystemConfig')
|
||||||
|
}
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
export default makeSocketIOEmitters;
|
300
frontend/src/app/socketio/listeners.ts
Normal file
300
frontend/src/app/socketio/listeners.ts
Normal file
@ -0,0 +1,300 @@
|
|||||||
|
import { AnyAction, MiddlewareAPI, Dispatch } from '@reduxjs/toolkit';
|
||||||
|
import { v4 as uuidv4 } from 'uuid';
|
||||||
|
import dateFormat from 'dateformat';
|
||||||
|
|
||||||
|
import * as InvokeAI from '../invokeai';
|
||||||
|
|
||||||
|
import {
|
||||||
|
addLogEntry,
|
||||||
|
setIsConnected,
|
||||||
|
setIsProcessing,
|
||||||
|
setSystemStatus,
|
||||||
|
setCurrentStatus,
|
||||||
|
setSystemConfig,
|
||||||
|
} from '../../features/system/systemSlice';
|
||||||
|
|
||||||
|
import {
|
||||||
|
addImage,
|
||||||
|
clearIntermediateImage,
|
||||||
|
removeImage,
|
||||||
|
setGalleryImages,
|
||||||
|
setIntermediateImage,
|
||||||
|
} from '../../features/gallery/gallerySlice';
|
||||||
|
|
||||||
|
import {
|
||||||
|
setInitialImagePath,
|
||||||
|
setMaskPath,
|
||||||
|
} from '../../features/options/optionsSlice';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns an object containing listener callbacks for socketio events.
|
||||||
|
* TODO: This file is large, but simple. Should it be split up further?
|
||||||
|
*/
|
||||||
|
const makeSocketIOListeners = (
|
||||||
|
store: MiddlewareAPI<Dispatch<AnyAction>, any>
|
||||||
|
) => {
|
||||||
|
const { dispatch, getState } = store;
|
||||||
|
|
||||||
|
return {
|
||||||
|
/**
|
||||||
|
* Callback to run when we receive a 'connect' event.
|
||||||
|
*/
|
||||||
|
onConnect: () => {
|
||||||
|
try {
|
||||||
|
dispatch(setIsConnected(true));
|
||||||
|
dispatch(setCurrentStatus('Connected'));
|
||||||
|
} catch (e) {
|
||||||
|
console.error(e);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
/**
|
||||||
|
* Callback to run when we receive a 'disconnect' event.
|
||||||
|
*/
|
||||||
|
onDisconnect: () => {
|
||||||
|
try {
|
||||||
|
dispatch(setIsConnected(false));
|
||||||
|
dispatch(setIsProcessing(false));
|
||||||
|
dispatch(setCurrentStatus('Disconnected'));
|
||||||
|
|
||||||
|
dispatch(
|
||||||
|
addLogEntry({
|
||||||
|
timestamp: dateFormat(new Date(), 'isoDateTime'),
|
||||||
|
message: `Disconnected from server`,
|
||||||
|
level: 'warning',
|
||||||
|
})
|
||||||
|
);
|
||||||
|
} catch (e) {
|
||||||
|
console.error(e);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
/**
|
||||||
|
* Callback to run when we receive a 'generationResult' event.
|
||||||
|
*/
|
||||||
|
onGenerationResult: (data: InvokeAI.ImageResultResponse) => {
|
||||||
|
try {
|
||||||
|
const { url, metadata } = data;
|
||||||
|
const newUuid = uuidv4();
|
||||||
|
|
||||||
|
dispatch(
|
||||||
|
addImage({
|
||||||
|
uuid: newUuid,
|
||||||
|
url,
|
||||||
|
metadata: metadata,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
dispatch(
|
||||||
|
addLogEntry({
|
||||||
|
timestamp: dateFormat(new Date(), 'isoDateTime'),
|
||||||
|
message: `Image generated: ${url}`,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
dispatch(setIsProcessing(false));
|
||||||
|
} catch (e) {
|
||||||
|
console.error(e);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
/**
|
||||||
|
* Callback to run when we receive a 'intermediateResult' event.
|
||||||
|
*/
|
||||||
|
onIntermediateResult: (data: InvokeAI.ImageResultResponse) => {
|
||||||
|
try {
|
||||||
|
const uuid = uuidv4();
|
||||||
|
const { url, metadata } = data;
|
||||||
|
dispatch(
|
||||||
|
setIntermediateImage({
|
||||||
|
uuid,
|
||||||
|
url,
|
||||||
|
metadata,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
dispatch(
|
||||||
|
addLogEntry({
|
||||||
|
timestamp: dateFormat(new Date(), 'isoDateTime'),
|
||||||
|
message: `Intermediate image generated: ${url}`,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
dispatch(setIsProcessing(false));
|
||||||
|
} catch (e) {
|
||||||
|
console.error(e);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
/**
|
||||||
|
* Callback to run when we receive an 'esrganResult' event.
|
||||||
|
*/
|
||||||
|
onESRGANResult: (data: InvokeAI.ImageResultResponse) => {
|
||||||
|
try {
|
||||||
|
const { url, metadata } = data;
|
||||||
|
|
||||||
|
dispatch(
|
||||||
|
addImage({
|
||||||
|
uuid: uuidv4(),
|
||||||
|
url,
|
||||||
|
metadata,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
|
||||||
|
dispatch(
|
||||||
|
addLogEntry({
|
||||||
|
timestamp: dateFormat(new Date(), 'isoDateTime'),
|
||||||
|
message: `Upscaled: ${url}`,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
dispatch(setIsProcessing(false));
|
||||||
|
} catch (e) {
|
||||||
|
console.error(e);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
/**
|
||||||
|
* Callback to run when we receive a 'gfpganResult' event.
|
||||||
|
*/
|
||||||
|
onGFPGANResult: (data: InvokeAI.ImageResultResponse) => {
|
||||||
|
try {
|
||||||
|
const { url, metadata } = data;
|
||||||
|
|
||||||
|
dispatch(
|
||||||
|
addImage({
|
||||||
|
uuid: uuidv4(),
|
||||||
|
url,
|
||||||
|
metadata,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
|
||||||
|
dispatch(
|
||||||
|
addLogEntry({
|
||||||
|
timestamp: dateFormat(new Date(), 'isoDateTime'),
|
||||||
|
message: `Fixed faces: ${url}`,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
} catch (e) {
|
||||||
|
console.error(e);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
/**
|
||||||
|
* Callback to run when we receive a 'progressUpdate' event.
|
||||||
|
* TODO: Add additional progress phases
|
||||||
|
*/
|
||||||
|
onProgressUpdate: (data: InvokeAI.SystemStatus) => {
|
||||||
|
try {
|
||||||
|
dispatch(setIsProcessing(true));
|
||||||
|
dispatch(setSystemStatus(data));
|
||||||
|
} catch (e) {
|
||||||
|
console.error(e);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
/**
|
||||||
|
* Callback to run when we receive a 'progressUpdate' event.
|
||||||
|
*/
|
||||||
|
onError: (data: InvokeAI.ErrorResponse) => {
|
||||||
|
const { message, additionalData } = data;
|
||||||
|
|
||||||
|
if (additionalData) {
|
||||||
|
// TODO: handle more data than short message
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
dispatch(
|
||||||
|
addLogEntry({
|
||||||
|
timestamp: dateFormat(new Date(), 'isoDateTime'),
|
||||||
|
message: `Server error: ${message}`,
|
||||||
|
level: 'error',
|
||||||
|
})
|
||||||
|
);
|
||||||
|
dispatch(setIsProcessing(false));
|
||||||
|
dispatch(clearIntermediateImage());
|
||||||
|
} catch (e) {
|
||||||
|
console.error(e);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
/**
|
||||||
|
* Callback to run when we receive a 'galleryImages' event.
|
||||||
|
*/
|
||||||
|
onGalleryImages: (data: InvokeAI.GalleryImagesResponse) => {
|
||||||
|
const { images } = data;
|
||||||
|
const preparedImages = images.map((image): InvokeAI.Image => {
|
||||||
|
const { url, metadata } = image;
|
||||||
|
return {
|
||||||
|
uuid: uuidv4(),
|
||||||
|
url,
|
||||||
|
metadata,
|
||||||
|
};
|
||||||
|
});
|
||||||
|
dispatch(setGalleryImages(preparedImages));
|
||||||
|
dispatch(
|
||||||
|
addLogEntry({
|
||||||
|
timestamp: dateFormat(new Date(), 'isoDateTime'),
|
||||||
|
message: `Loaded ${images.length} images`,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
},
|
||||||
|
/**
|
||||||
|
* Callback to run when we receive a 'processingCanceled' event.
|
||||||
|
*/
|
||||||
|
onProcessingCanceled: () => {
|
||||||
|
dispatch(setIsProcessing(false));
|
||||||
|
|
||||||
|
const { intermediateImage } = getState().gallery;
|
||||||
|
|
||||||
|
if (intermediateImage) {
|
||||||
|
dispatch(addImage(intermediateImage));
|
||||||
|
dispatch(
|
||||||
|
addLogEntry({
|
||||||
|
timestamp: dateFormat(new Date(), 'isoDateTime'),
|
||||||
|
message: `Intermediate image saved: ${intermediateImage.url}`,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
dispatch(clearIntermediateImage());
|
||||||
|
}
|
||||||
|
dispatch(
|
||||||
|
addLogEntry({
|
||||||
|
timestamp: dateFormat(new Date(), 'isoDateTime'),
|
||||||
|
message: `Processing canceled`,
|
||||||
|
level: 'warning',
|
||||||
|
})
|
||||||
|
);
|
||||||
|
},
|
||||||
|
/**
|
||||||
|
* Callback to run when we receive a 'imageDeleted' event.
|
||||||
|
*/
|
||||||
|
onImageDeleted: (data: InvokeAI.ImageUrlAndUuidResponse) => {
|
||||||
|
const { url, uuid } = data;
|
||||||
|
dispatch(removeImage(uuid));
|
||||||
|
dispatch(
|
||||||
|
addLogEntry({
|
||||||
|
timestamp: dateFormat(new Date(), 'isoDateTime'),
|
||||||
|
message: `Image deleted: ${url}`,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
},
|
||||||
|
/**
|
||||||
|
* Callback to run when we receive a 'initialImageUploaded' event.
|
||||||
|
*/
|
||||||
|
onInitialImageUploaded: (data: InvokeAI.ImageUrlResponse) => {
|
||||||
|
const { url } = data;
|
||||||
|
dispatch(setInitialImagePath(url));
|
||||||
|
dispatch(
|
||||||
|
addLogEntry({
|
||||||
|
timestamp: dateFormat(new Date(), 'isoDateTime'),
|
||||||
|
message: `Initial image uploaded: ${url}`,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
},
|
||||||
|
/**
|
||||||
|
* Callback to run when we receive a 'maskImageUploaded' event.
|
||||||
|
*/
|
||||||
|
onMaskImageUploaded: (data: InvokeAI.ImageUrlResponse) => {
|
||||||
|
const { url } = data;
|
||||||
|
dispatch(setMaskPath(url));
|
||||||
|
dispatch(
|
||||||
|
addLogEntry({
|
||||||
|
timestamp: dateFormat(new Date(), 'isoDateTime'),
|
||||||
|
message: `Mask image uploaded: ${url}`,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
},
|
||||||
|
onSystemConfig: (data: InvokeAI.SystemConfig) => {
|
||||||
|
dispatch(setSystemConfig(data));
|
||||||
|
},
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
export default makeSocketIOListeners;
|
173
frontend/src/app/socketio/middleware.ts
Normal file
173
frontend/src/app/socketio/middleware.ts
Normal file
@ -0,0 +1,173 @@
|
|||||||
|
import { Middleware } from '@reduxjs/toolkit';
|
||||||
|
import { io } from 'socket.io-client';
|
||||||
|
|
||||||
|
import makeSocketIOListeners from './listeners';
|
||||||
|
import makeSocketIOEmitters from './emitters';
|
||||||
|
|
||||||
|
import * as InvokeAI from '../invokeai';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a socketio middleware to handle communication with server.
|
||||||
|
*
|
||||||
|
* Special `socketio/actionName` actions are created in actions.ts and
|
||||||
|
* exported for use by the application, which treats them like any old
|
||||||
|
* action, using `dispatch` to dispatch them.
|
||||||
|
*
|
||||||
|
* These actions are intercepted here, where `socketio.emit()` calls are
|
||||||
|
* made on their behalf - see `emitters.ts`. The emitter functions
|
||||||
|
* are the outbound communication to the server.
|
||||||
|
*
|
||||||
|
* Listeners are also established here - see `listeners.ts`. The listener
|
||||||
|
* functions receive communication from the server and usually dispatch
|
||||||
|
* some new action to handle whatever data was sent from the server.
|
||||||
|
*/
|
||||||
|
export const socketioMiddleware = () => {
|
||||||
|
const { hostname, port } = new URL(window.location.href);
|
||||||
|
|
||||||
|
const socketio = io(`http://${hostname}:9090`);
|
||||||
|
|
||||||
|
let areListenersSet = false;
|
||||||
|
|
||||||
|
const middleware: Middleware = (store) => (next) => (action) => {
|
||||||
|
const {
|
||||||
|
onConnect,
|
||||||
|
onDisconnect,
|
||||||
|
onError,
|
||||||
|
onESRGANResult,
|
||||||
|
onGFPGANResult,
|
||||||
|
onGenerationResult,
|
||||||
|
onIntermediateResult,
|
||||||
|
onProgressUpdate,
|
||||||
|
onGalleryImages,
|
||||||
|
onProcessingCanceled,
|
||||||
|
onImageDeleted,
|
||||||
|
onInitialImageUploaded,
|
||||||
|
onMaskImageUploaded,
|
||||||
|
onSystemConfig,
|
||||||
|
} = makeSocketIOListeners(store);
|
||||||
|
|
||||||
|
const {
|
||||||
|
emitGenerateImage,
|
||||||
|
emitRunESRGAN,
|
||||||
|
emitRunGFPGAN,
|
||||||
|
emitDeleteImage,
|
||||||
|
emitRequestAllImages,
|
||||||
|
emitCancelProcessing,
|
||||||
|
emitUploadInitialImage,
|
||||||
|
emitUploadMaskImage,
|
||||||
|
emitRequestSystemConfig,
|
||||||
|
} = makeSocketIOEmitters(store, socketio);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* If this is the first time the middleware has been called (e.g. during store setup),
|
||||||
|
* initialize all our socket.io listeners.
|
||||||
|
*/
|
||||||
|
if (!areListenersSet) {
|
||||||
|
socketio.on('connect', () => onConnect());
|
||||||
|
|
||||||
|
socketio.on('disconnect', () => onDisconnect());
|
||||||
|
|
||||||
|
socketio.on('error', (data: InvokeAI.ErrorResponse) => onError(data));
|
||||||
|
|
||||||
|
socketio.on('generationResult', (data: InvokeAI.ImageResultResponse) =>
|
||||||
|
onGenerationResult(data)
|
||||||
|
);
|
||||||
|
|
||||||
|
socketio.on('esrganResult', (data: InvokeAI.ImageResultResponse) =>
|
||||||
|
onESRGANResult(data)
|
||||||
|
);
|
||||||
|
|
||||||
|
socketio.on('gfpganResult', (data: InvokeAI.ImageResultResponse) =>
|
||||||
|
onGFPGANResult(data)
|
||||||
|
);
|
||||||
|
|
||||||
|
socketio.on('intermediateResult', (data: InvokeAI.ImageResultResponse) =>
|
||||||
|
onIntermediateResult(data)
|
||||||
|
);
|
||||||
|
|
||||||
|
socketio.on('progressUpdate', (data: InvokeAI.SystemStatus) =>
|
||||||
|
onProgressUpdate(data)
|
||||||
|
);
|
||||||
|
|
||||||
|
socketio.on('galleryImages', (data: InvokeAI.GalleryImagesResponse) =>
|
||||||
|
onGalleryImages(data)
|
||||||
|
);
|
||||||
|
|
||||||
|
socketio.on('processingCanceled', () => {
|
||||||
|
onProcessingCanceled();
|
||||||
|
});
|
||||||
|
|
||||||
|
socketio.on('imageDeleted', (data: InvokeAI.ImageUrlAndUuidResponse) => {
|
||||||
|
onImageDeleted(data);
|
||||||
|
});
|
||||||
|
|
||||||
|
socketio.on('initialImageUploaded', (data: InvokeAI.ImageUrlResponse) => {
|
||||||
|
onInitialImageUploaded(data);
|
||||||
|
});
|
||||||
|
|
||||||
|
socketio.on('maskImageUploaded', (data: InvokeAI.ImageUrlResponse) => {
|
||||||
|
onMaskImageUploaded(data);
|
||||||
|
});
|
||||||
|
|
||||||
|
socketio.on('systemConfig', (data: InvokeAI.SystemConfig) => {
|
||||||
|
onSystemConfig(data);
|
||||||
|
});
|
||||||
|
|
||||||
|
areListenersSet = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Handle redux actions caught by middleware.
|
||||||
|
*/
|
||||||
|
switch (action.type) {
|
||||||
|
case 'socketio/generateImage': {
|
||||||
|
emitGenerateImage();
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
case 'socketio/runESRGAN': {
|
||||||
|
emitRunESRGAN(action.payload);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
case 'socketio/runGFPGAN': {
|
||||||
|
emitRunGFPGAN(action.payload);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
case 'socketio/deleteImage': {
|
||||||
|
emitDeleteImage(action.payload);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
case 'socketio/requestAllImages': {
|
||||||
|
emitRequestAllImages();
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
case 'socketio/cancelProcessing': {
|
||||||
|
emitCancelProcessing();
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
case 'socketio/uploadInitialImage': {
|
||||||
|
emitUploadInitialImage(action.payload);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
case 'socketio/uploadMaskImage': {
|
||||||
|
emitUploadMaskImage(action.payload);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
case 'socketio/requestSystemConfig': {
|
||||||
|
emitRequestSystemConfig();
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
next(action);
|
||||||
|
};
|
||||||
|
|
||||||
|
return middleware;
|
||||||
|
};
|
@ -1,53 +1,78 @@
|
|||||||
import { combineReducers, configureStore } from '@reduxjs/toolkit';
|
import { combineReducers, configureStore } from '@reduxjs/toolkit';
|
||||||
|
import { useDispatch, useSelector } from 'react-redux';
|
||||||
|
import type { TypedUseSelectorHook } from 'react-redux';
|
||||||
|
|
||||||
import { persistReducer } from 'redux-persist';
|
import { persistReducer } from 'redux-persist';
|
||||||
import storage from 'redux-persist/lib/storage'; // defaults to localStorage for web
|
import storage from 'redux-persist/lib/storage'; // defaults to localStorage for web
|
||||||
|
|
||||||
import sdReducer from '../features/sd/sdSlice';
|
import optionsReducer from '../features/options/optionsSlice';
|
||||||
import galleryReducer from '../features/gallery/gallerySlice';
|
import galleryReducer from '../features/gallery/gallerySlice';
|
||||||
import systemReducer from '../features/system/systemSlice';
|
import systemReducer from '../features/system/systemSlice';
|
||||||
import { socketioMiddleware } from './socketio';
|
import { socketioMiddleware } from './socketio/middleware';
|
||||||
|
|
||||||
const reducers = combineReducers({
|
/**
|
||||||
sd: sdReducer,
|
* redux-persist provides an easy and reliable way to persist state across reloads.
|
||||||
gallery: galleryReducer,
|
*
|
||||||
system: systemReducer,
|
* While we definitely want generation parameters to be persisted, there are a number
|
||||||
});
|
* of things we do *not* want to be persisted across reloads:
|
||||||
|
* - Gallery/selected image (user may add/delete images from disk between page loads)
|
||||||
|
* - Connection/processing status
|
||||||
|
* - Availability of external libraries like ESRGAN/GFPGAN
|
||||||
|
*
|
||||||
|
* These can be blacklisted in redux-persist.
|
||||||
|
*
|
||||||
|
* The necesssary nested persistors with blacklists are configured below.
|
||||||
|
*
|
||||||
|
* TODO: Do we blacklist initialImagePath? If the image is deleted from disk we get an
|
||||||
|
* ugly 404. But if we blacklist it, then this is a valuable parameter that is lost
|
||||||
|
* on reload. Need to figure out a good way to handle this.
|
||||||
|
*/
|
||||||
|
|
||||||
const persistConfig = {
|
const rootPersistConfig = {
|
||||||
key: 'root',
|
key: 'root',
|
||||||
storage,
|
storage,
|
||||||
|
blacklist: ['gallery', 'system'],
|
||||||
};
|
};
|
||||||
|
|
||||||
const persistedReducer = persistReducer(persistConfig, reducers);
|
const systemPersistConfig = {
|
||||||
|
key: 'system',
|
||||||
|
storage,
|
||||||
|
blacklist: [
|
||||||
|
'isConnected',
|
||||||
|
'isProcessing',
|
||||||
|
'currentStep',
|
||||||
|
'socketId',
|
||||||
|
'isESRGANAvailable',
|
||||||
|
'isGFPGANAvailable',
|
||||||
|
'currentStep',
|
||||||
|
'totalSteps',
|
||||||
|
'currentIteration',
|
||||||
|
'totalIterations',
|
||||||
|
'currentStatus',
|
||||||
|
],
|
||||||
|
};
|
||||||
|
|
||||||
/*
|
const reducers = combineReducers({
|
||||||
The frontend needs to be distributed as a production build, so
|
options: optionsReducer,
|
||||||
we cannot reasonably ask users to edit the JS and specify the
|
gallery: galleryReducer,
|
||||||
host and port on which the socket.io server will run.
|
system: persistReducer(systemPersistConfig, systemReducer),
|
||||||
|
});
|
||||||
The solution is to allow server script to be run with arguments
|
|
||||||
(or just edited) providing the host and port. Then, the server
|
|
||||||
serves a route `/socketio_config` which responds with the host
|
|
||||||
and port.
|
|
||||||
|
|
||||||
When the frontend loads, it synchronously requests that route
|
|
||||||
and thus gets the host and port. This requires a suspicious
|
|
||||||
fetch somewhere, and the store setup seems like as good a place
|
|
||||||
as any to make this fetch request.
|
|
||||||
*/
|
|
||||||
|
|
||||||
|
const persistedReducer = persistReducer(rootPersistConfig, reducers);
|
||||||
|
|
||||||
// Continue with store setup
|
// Continue with store setup
|
||||||
export const store = configureStore({
|
export const store = configureStore({
|
||||||
reducer: persistedReducer,
|
reducer: persistedReducer,
|
||||||
middleware: (getDefaultMiddleware) =>
|
middleware: (getDefaultMiddleware) =>
|
||||||
getDefaultMiddleware({
|
getDefaultMiddleware({
|
||||||
// redux-persist sometimes needs to have a function in redux, need to disable this check
|
// redux-persist sometimes needs to temporarily put a function in redux state, need to disable this check
|
||||||
serializableCheck: false,
|
serializableCheck: false,
|
||||||
}).concat(socketioMiddleware()),
|
}).concat(socketioMiddleware()),
|
||||||
});
|
});
|
||||||
|
|
||||||
// Infer the `RootState` and `AppDispatch` types from the store itself
|
|
||||||
export type RootState = ReturnType<typeof store.getState>;
|
export type RootState = ReturnType<typeof store.getState>;
|
||||||
// Inferred type: {posts: PostsState, comments: CommentsState, users: UsersState}
|
|
||||||
export type AppDispatch = typeof store.dispatch;
|
export type AppDispatch = typeof store.dispatch;
|
||||||
|
|
||||||
|
// Use throughout your app instead of plain `useDispatch` and `useSelector`
|
||||||
|
export const useAppDispatch: () => AppDispatch = useDispatch;
|
||||||
|
export const useAppSelector: TypedUseSelectorHook<RootState> = useSelector;
|
||||||
|
@ -33,5 +33,20 @@ export const theme = extendTheme({
|
|||||||
fontWeight: 'light',
|
fontWeight: 'light',
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
Button: {
|
||||||
|
variants: {
|
||||||
|
imageHoverIconButton: (props: StyleFunctionProps) => ({
|
||||||
|
bg: props.colorMode === 'dark' ? 'blackAlpha.700' : 'whiteAlpha.800',
|
||||||
|
color:
|
||||||
|
props.colorMode === 'dark' ? 'whiteAlpha.700' : 'blackAlpha.700',
|
||||||
|
_hover: {
|
||||||
|
bg:
|
||||||
|
props.colorMode === 'dark' ? 'blackAlpha.800' : 'whiteAlpha.800',
|
||||||
|
color:
|
||||||
|
props.colorMode === 'dark' ? 'whiteAlpha.900' : 'blackAlpha.900',
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
},
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
21
frontend/src/common/components/SDButton.tsx
Normal file
21
frontend/src/common/components/SDButton.tsx
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
import { Button, ButtonProps } from '@chakra-ui/react';
|
||||||
|
|
||||||
|
interface Props extends ButtonProps {
|
||||||
|
label: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Reusable customized button component. Originally was more customized - now probably unecessary.
|
||||||
|
*
|
||||||
|
* TODO: Get rid of this.
|
||||||
|
*/
|
||||||
|
const SDButton = (props: Props) => {
|
||||||
|
const { label, size = 'sm', ...rest } = props;
|
||||||
|
return (
|
||||||
|
<Button size={size} {...rest}>
|
||||||
|
{label}
|
||||||
|
</Button>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default SDButton;
|
@ -16,6 +16,9 @@ interface Props extends NumberInputProps {
|
|||||||
width?: string | number;
|
width?: string | number;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Customized Chakra FormControl + NumberInput multi-part component.
|
||||||
|
*/
|
||||||
const SDNumberInput = (props: Props) => {
|
const SDNumberInput = (props: Props) => {
|
||||||
const {
|
const {
|
||||||
label,
|
label,
|
||||||
@ -31,7 +34,7 @@ const SDNumberInput = (props: Props) => {
|
|||||||
<Flex gap={2} justifyContent={'space-between'} alignItems={'center'}>
|
<Flex gap={2} justifyContent={'space-between'} alignItems={'center'}>
|
||||||
{label && (
|
{label && (
|
||||||
<FormLabel marginBottom={1}>
|
<FormLabel marginBottom={1}>
|
||||||
<Text fontSize={fontSize} whiteSpace='nowrap'>
|
<Text fontSize={fontSize} whiteSpace="nowrap">
|
||||||
{label}
|
{label}
|
||||||
</Text>
|
</Text>
|
||||||
</FormLabel>
|
</FormLabel>
|
||||||
@ -42,7 +45,7 @@ const SDNumberInput = (props: Props) => {
|
|||||||
keepWithinRange={false}
|
keepWithinRange={false}
|
||||||
clampValueOnBlur={true}
|
clampValueOnBlur={true}
|
||||||
>
|
>
|
||||||
<NumberInputField fontSize={'md'}/>
|
<NumberInputField fontSize={'md'} />
|
||||||
<NumberInputStepper>
|
<NumberInputStepper>
|
||||||
<NumberIncrementStepper />
|
<NumberIncrementStepper />
|
||||||
<NumberDecrementStepper />
|
<NumberDecrementStepper />
|
56
frontend/src/common/components/SDSelect.tsx
Normal file
56
frontend/src/common/components/SDSelect.tsx
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
import {
|
||||||
|
Flex,
|
||||||
|
FormControl,
|
||||||
|
FormLabel,
|
||||||
|
Select,
|
||||||
|
SelectProps,
|
||||||
|
Text,
|
||||||
|
} from '@chakra-ui/react';
|
||||||
|
|
||||||
|
interface Props extends SelectProps {
|
||||||
|
label: string;
|
||||||
|
validValues:
|
||||||
|
| Array<number | string>
|
||||||
|
| Array<{ key: string; value: string | number }>;
|
||||||
|
}
|
||||||
|
/**
|
||||||
|
* Customized Chakra FormControl + Select multi-part component.
|
||||||
|
*/
|
||||||
|
const SDSelect = (props: Props) => {
|
||||||
|
const {
|
||||||
|
label,
|
||||||
|
isDisabled,
|
||||||
|
validValues,
|
||||||
|
size = 'sm',
|
||||||
|
fontSize = 'md',
|
||||||
|
marginBottom = 1,
|
||||||
|
whiteSpace = 'nowrap',
|
||||||
|
...rest
|
||||||
|
} = props;
|
||||||
|
return (
|
||||||
|
<FormControl isDisabled={isDisabled}>
|
||||||
|
<Flex justifyContent={'space-between'} alignItems={'center'}>
|
||||||
|
<FormLabel marginBottom={marginBottom}>
|
||||||
|
<Text fontSize={fontSize} whiteSpace={whiteSpace}>
|
||||||
|
{label}
|
||||||
|
</Text>
|
||||||
|
</FormLabel>
|
||||||
|
<Select fontSize={fontSize} size={size} {...rest}>
|
||||||
|
{validValues.map((opt) => {
|
||||||
|
return typeof opt === 'string' || typeof opt === 'number' ? (
|
||||||
|
<option key={opt} value={opt}>
|
||||||
|
{opt}
|
||||||
|
</option>
|
||||||
|
) : (
|
||||||
|
<option key={opt.value} value={opt.value}>
|
||||||
|
{opt.key}
|
||||||
|
</option>
|
||||||
|
);
|
||||||
|
})}
|
||||||
|
</Select>
|
||||||
|
</Flex>
|
||||||
|
</FormControl>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default SDSelect;
|
@ -11,6 +11,9 @@ interface Props extends SwitchProps {
|
|||||||
width?: string | number;
|
width?: string | number;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Customized Chakra FormControl + Switch multi-part component.
|
||||||
|
*/
|
||||||
const SDSwitch = (props: Props) => {
|
const SDSwitch = (props: Props) => {
|
||||||
const {
|
const {
|
||||||
label,
|
label,
|
||||||
@ -28,7 +31,7 @@ const SDSwitch = (props: Props) => {
|
|||||||
fontSize={fontSize}
|
fontSize={fontSize}
|
||||||
marginBottom={1}
|
marginBottom={1}
|
||||||
flexGrow={2}
|
flexGrow={2}
|
||||||
whiteSpace='nowrap'
|
whiteSpace="nowrap"
|
||||||
>
|
>
|
||||||
{label}
|
{label}
|
||||||
</FormLabel>
|
</FormLabel>
|
104
frontend/src/common/hooks/useCheckParameters.ts
Normal file
104
frontend/src/common/hooks/useCheckParameters.ts
Normal file
@ -0,0 +1,104 @@
|
|||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { isEqual } from 'lodash';
|
||||||
|
import { useMemo } from 'react';
|
||||||
|
import { useAppSelector } from '../../app/store';
|
||||||
|
import { RootState } from '../../app/store';
|
||||||
|
import { OptionsState } from '../../features/options/optionsSlice';
|
||||||
|
import { SystemState } from '../../features/system/systemSlice';
|
||||||
|
import { validateSeedWeights } from '../util/seedWeightPairs';
|
||||||
|
|
||||||
|
const optionsSelector = createSelector(
|
||||||
|
(state: RootState) => state.options,
|
||||||
|
(options: OptionsState) => {
|
||||||
|
return {
|
||||||
|
prompt: options.prompt,
|
||||||
|
shouldGenerateVariations: options.shouldGenerateVariations,
|
||||||
|
seedWeights: options.seedWeights,
|
||||||
|
maskPath: options.maskPath,
|
||||||
|
initialImagePath: options.initialImagePath,
|
||||||
|
seed: options.seed,
|
||||||
|
};
|
||||||
|
},
|
||||||
|
{
|
||||||
|
memoizeOptions: {
|
||||||
|
resultEqualityCheck: isEqual,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
const systemSelector = createSelector(
|
||||||
|
(state: RootState) => state.system,
|
||||||
|
(system: SystemState) => {
|
||||||
|
return {
|
||||||
|
isProcessing: system.isProcessing,
|
||||||
|
isConnected: system.isConnected,
|
||||||
|
};
|
||||||
|
},
|
||||||
|
{
|
||||||
|
memoizeOptions: {
|
||||||
|
resultEqualityCheck: isEqual,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Checks relevant pieces of state to confirm generation will not deterministically fail.
|
||||||
|
* This is used to prevent the 'Generate' button from being clicked.
|
||||||
|
*/
|
||||||
|
const useCheckParameters = (): boolean => {
|
||||||
|
const {
|
||||||
|
prompt,
|
||||||
|
shouldGenerateVariations,
|
||||||
|
seedWeights,
|
||||||
|
maskPath,
|
||||||
|
initialImagePath,
|
||||||
|
seed,
|
||||||
|
} = useAppSelector(optionsSelector);
|
||||||
|
|
||||||
|
const { isProcessing, isConnected } = useAppSelector(systemSelector);
|
||||||
|
|
||||||
|
return useMemo(() => {
|
||||||
|
// Cannot generate without a prompt
|
||||||
|
if (!prompt) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cannot generate with a mask without img2img
|
||||||
|
if (maskPath && !initialImagePath) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: job queue
|
||||||
|
// Cannot generate if already processing an image
|
||||||
|
if (isProcessing) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cannot generate if not connected
|
||||||
|
if (!isConnected) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cannot generate variations without valid seed weights
|
||||||
|
if (
|
||||||
|
shouldGenerateVariations &&
|
||||||
|
(!(validateSeedWeights(seedWeights) || seedWeights === '') || seed === -1)
|
||||||
|
) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// All good
|
||||||
|
return true;
|
||||||
|
}, [
|
||||||
|
prompt,
|
||||||
|
maskPath,
|
||||||
|
initialImagePath,
|
||||||
|
isProcessing,
|
||||||
|
isConnected,
|
||||||
|
shouldGenerateVariations,
|
||||||
|
seedWeights,
|
||||||
|
seed,
|
||||||
|
]);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default useCheckParameters;
|
182
frontend/src/common/util/parameterTranslation.ts
Normal file
182
frontend/src/common/util/parameterTranslation.ts
Normal file
@ -0,0 +1,182 @@
|
|||||||
|
/*
|
||||||
|
These functions translate frontend state into parameters
|
||||||
|
suitable for consumption by the backend, and vice-versa.
|
||||||
|
*/
|
||||||
|
|
||||||
|
import { NUMPY_RAND_MAX, NUMPY_RAND_MIN } from '../../app/constants';
|
||||||
|
import { OptionsState } from '../../features/options/optionsSlice';
|
||||||
|
import { SystemState } from '../../features/system/systemSlice';
|
||||||
|
import {
|
||||||
|
seedWeightsToString,
|
||||||
|
stringToSeedWeightsArray,
|
||||||
|
} from './seedWeightPairs';
|
||||||
|
import randomInt from './randomInt';
|
||||||
|
|
||||||
|
export const frontendToBackendParameters = (
|
||||||
|
optionsState: OptionsState,
|
||||||
|
systemState: SystemState
|
||||||
|
): { [key: string]: any } => {
|
||||||
|
const {
|
||||||
|
prompt,
|
||||||
|
iterations,
|
||||||
|
steps,
|
||||||
|
cfgScale,
|
||||||
|
height,
|
||||||
|
width,
|
||||||
|
sampler,
|
||||||
|
seed,
|
||||||
|
seamless,
|
||||||
|
shouldUseInitImage,
|
||||||
|
img2imgStrength,
|
||||||
|
initialImagePath,
|
||||||
|
maskPath,
|
||||||
|
shouldFitToWidthHeight,
|
||||||
|
shouldGenerateVariations,
|
||||||
|
variationAmount,
|
||||||
|
seedWeights,
|
||||||
|
shouldRunESRGAN,
|
||||||
|
upscalingLevel,
|
||||||
|
upscalingStrength,
|
||||||
|
shouldRunGFPGAN,
|
||||||
|
gfpganStrength,
|
||||||
|
shouldRandomizeSeed,
|
||||||
|
} = optionsState;
|
||||||
|
|
||||||
|
const { shouldDisplayInProgress } = systemState;
|
||||||
|
|
||||||
|
const generationParameters: { [k: string]: any } = {
|
||||||
|
prompt,
|
||||||
|
iterations,
|
||||||
|
steps,
|
||||||
|
cfg_scale: cfgScale,
|
||||||
|
height,
|
||||||
|
width,
|
||||||
|
sampler_name: sampler,
|
||||||
|
seed,
|
||||||
|
seamless,
|
||||||
|
progress_images: shouldDisplayInProgress,
|
||||||
|
};
|
||||||
|
|
||||||
|
generationParameters.seed = shouldRandomizeSeed
|
||||||
|
? randomInt(NUMPY_RAND_MIN, NUMPY_RAND_MAX)
|
||||||
|
: seed;
|
||||||
|
|
||||||
|
if (shouldUseInitImage) {
|
||||||
|
generationParameters.init_img = initialImagePath;
|
||||||
|
generationParameters.strength = img2imgStrength;
|
||||||
|
generationParameters.fit = shouldFitToWidthHeight;
|
||||||
|
if (maskPath) {
|
||||||
|
generationParameters.init_mask = maskPath;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (shouldGenerateVariations) {
|
||||||
|
generationParameters.variation_amount = variationAmount;
|
||||||
|
if (seedWeights) {
|
||||||
|
generationParameters.with_variations =
|
||||||
|
stringToSeedWeightsArray(seedWeights);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
generationParameters.variation_amount = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
let esrganParameters: false | { [k: string]: any } = false;
|
||||||
|
let gfpganParameters: false | { [k: string]: any } = false;
|
||||||
|
|
||||||
|
if (shouldRunESRGAN) {
|
||||||
|
esrganParameters = {
|
||||||
|
level: upscalingLevel,
|
||||||
|
strength: upscalingStrength,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
if (shouldRunGFPGAN) {
|
||||||
|
gfpganParameters = {
|
||||||
|
strength: gfpganStrength,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
generationParameters,
|
||||||
|
esrganParameters,
|
||||||
|
gfpganParameters,
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
export const backendToFrontendParameters = (parameters: {
|
||||||
|
[key: string]: any;
|
||||||
|
}) => {
|
||||||
|
const {
|
||||||
|
prompt,
|
||||||
|
iterations,
|
||||||
|
steps,
|
||||||
|
cfg_scale,
|
||||||
|
height,
|
||||||
|
width,
|
||||||
|
sampler_name,
|
||||||
|
seed,
|
||||||
|
seamless,
|
||||||
|
progress_images,
|
||||||
|
variation_amount,
|
||||||
|
with_variations,
|
||||||
|
gfpgan_strength,
|
||||||
|
upscale,
|
||||||
|
init_img,
|
||||||
|
init_mask,
|
||||||
|
strength,
|
||||||
|
} = parameters;
|
||||||
|
|
||||||
|
const options: { [key: string]: any } = {
|
||||||
|
shouldDisplayInProgress: progress_images,
|
||||||
|
// init
|
||||||
|
shouldGenerateVariations: false,
|
||||||
|
shouldRunESRGAN: false,
|
||||||
|
shouldRunGFPGAN: false,
|
||||||
|
initialImagePath: '',
|
||||||
|
maskPath: '',
|
||||||
|
};
|
||||||
|
|
||||||
|
if (variation_amount > 0) {
|
||||||
|
options.shouldGenerateVariations = true;
|
||||||
|
options.variationAmount = variation_amount;
|
||||||
|
if (with_variations) {
|
||||||
|
options.seedWeights = seedWeightsToString(with_variations);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (gfpgan_strength > 0) {
|
||||||
|
options.shouldRunGFPGAN = true;
|
||||||
|
options.gfpganStrength = gfpgan_strength;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (upscale) {
|
||||||
|
options.shouldRunESRGAN = true;
|
||||||
|
options.upscalingLevel = upscale[0];
|
||||||
|
options.upscalingStrength = upscale[1];
|
||||||
|
}
|
||||||
|
|
||||||
|
if (init_img) {
|
||||||
|
options.shouldUseInitImage = true;
|
||||||
|
options.initialImagePath = init_img;
|
||||||
|
options.strength = strength;
|
||||||
|
if (init_mask) {
|
||||||
|
options.maskPath = init_mask;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// if we had a prompt, add all the metadata, but if we don't have a prompt,
|
||||||
|
// we must have only done ESRGAN or GFPGAN so do not add that metadata
|
||||||
|
if (prompt) {
|
||||||
|
options.prompt = prompt;
|
||||||
|
options.iterations = iterations;
|
||||||
|
options.steps = steps;
|
||||||
|
options.cfgScale = cfg_scale;
|
||||||
|
options.height = height;
|
||||||
|
options.width = width;
|
||||||
|
options.sampler = sampler_name;
|
||||||
|
options.seed = seed;
|
||||||
|
options.seamless = seamless;
|
||||||
|
}
|
||||||
|
|
||||||
|
return options;
|
||||||
|
};
|
16
frontend/src/common/util/promptToString.ts
Normal file
16
frontend/src/common/util/promptToString.ts
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
import * as InvokeAI from '../../app/invokeai';
|
||||||
|
|
||||||
|
const promptToString = (prompt: InvokeAI.Prompt): string => {
|
||||||
|
if (prompt.length === 1) {
|
||||||
|
return prompt[0].prompt;
|
||||||
|
}
|
||||||
|
|
||||||
|
return prompt
|
||||||
|
.map(
|
||||||
|
(promptItem: InvokeAI.PromptItem): string =>
|
||||||
|
`${promptItem.prompt}:${promptItem.weight}`
|
||||||
|
)
|
||||||
|
.join(' ');
|
||||||
|
};
|
||||||
|
|
||||||
|
export default promptToString;
|
68
frontend/src/common/util/seedWeightPairs.ts
Normal file
68
frontend/src/common/util/seedWeightPairs.ts
Normal file
@ -0,0 +1,68 @@
|
|||||||
|
import * as InvokeAI from '../../app/invokeai';
|
||||||
|
|
||||||
|
export const stringToSeedWeights = (
|
||||||
|
string: string
|
||||||
|
): InvokeAI.SeedWeights | boolean => {
|
||||||
|
const stringPairs = string.split(',');
|
||||||
|
const arrPairs = stringPairs.map((p) => p.split(':'));
|
||||||
|
const pairs = arrPairs.map((p: Array<string>): InvokeAI.SeedWeightPair => {
|
||||||
|
return { seed: parseInt(p[0]), weight: parseFloat(p[1]) };
|
||||||
|
});
|
||||||
|
|
||||||
|
if (!validateSeedWeights(pairs)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return pairs;
|
||||||
|
};
|
||||||
|
|
||||||
|
export const validateSeedWeights = (
|
||||||
|
seedWeights: InvokeAI.SeedWeights | string
|
||||||
|
): boolean => {
|
||||||
|
return typeof seedWeights === 'string'
|
||||||
|
? Boolean(stringToSeedWeights(seedWeights))
|
||||||
|
: Boolean(
|
||||||
|
seedWeights.length &&
|
||||||
|
!seedWeights.some((pair: InvokeAI.SeedWeightPair) => {
|
||||||
|
const { seed, weight } = pair;
|
||||||
|
const isSeedValid = !isNaN(parseInt(seed.toString(), 10));
|
||||||
|
const isWeightValid =
|
||||||
|
!isNaN(parseInt(weight.toString(), 10)) &&
|
||||||
|
weight >= 0 &&
|
||||||
|
weight <= 1;
|
||||||
|
return !(isSeedValid && isWeightValid);
|
||||||
|
})
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export const seedWeightsToString = (
|
||||||
|
seedWeights: InvokeAI.SeedWeights
|
||||||
|
): string => {
|
||||||
|
return seedWeights.reduce((acc, pair, i, arr) => {
|
||||||
|
const { seed, weight } = pair;
|
||||||
|
acc += `${seed}:${weight}`;
|
||||||
|
if (i !== arr.length - 1) {
|
||||||
|
acc += ',';
|
||||||
|
}
|
||||||
|
return acc;
|
||||||
|
}, '');
|
||||||
|
};
|
||||||
|
|
||||||
|
export const seedWeightsToArray = (
|
||||||
|
seedWeights: InvokeAI.SeedWeights
|
||||||
|
): Array<Array<number>> => {
|
||||||
|
return seedWeights.map((pair: InvokeAI.SeedWeightPair) => [
|
||||||
|
pair.seed,
|
||||||
|
pair.weight,
|
||||||
|
]);
|
||||||
|
};
|
||||||
|
|
||||||
|
export const stringToSeedWeightsArray = (
|
||||||
|
string: string
|
||||||
|
): Array<Array<number>> => {
|
||||||
|
const stringPairs = string.split(',');
|
||||||
|
const arrPairs = stringPairs.map((p) => p.split(':'));
|
||||||
|
return arrPairs.map(
|
||||||
|
(p: Array<string>): Array<number> => [parseInt(p[0]), parseFloat(p[1])]
|
||||||
|
);
|
||||||
|
};
|
@ -1,16 +0,0 @@
|
|||||||
import { Button, ButtonProps } from '@chakra-ui/react';
|
|
||||||
|
|
||||||
interface Props extends ButtonProps {
|
|
||||||
label: string;
|
|
||||||
}
|
|
||||||
|
|
||||||
const SDButton = (props: Props) => {
|
|
||||||
const { label, size = 'sm', ...rest } = props;
|
|
||||||
return (
|
|
||||||
<Button size={size} {...rest}>
|
|
||||||
{label}
|
|
||||||
</Button>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
export default SDButton;
|
|
@ -1,57 +0,0 @@
|
|||||||
import {
|
|
||||||
Flex,
|
|
||||||
FormControl,
|
|
||||||
FormLabel,
|
|
||||||
Select,
|
|
||||||
SelectProps,
|
|
||||||
Text,
|
|
||||||
} from '@chakra-ui/react';
|
|
||||||
|
|
||||||
interface Props extends SelectProps {
|
|
||||||
label: string;
|
|
||||||
validValues:
|
|
||||||
| Array<number | string>
|
|
||||||
| Array<{ key: string; value: string | number }>;
|
|
||||||
}
|
|
||||||
|
|
||||||
const SDSelect = (props: Props) => {
|
|
||||||
const {
|
|
||||||
label,
|
|
||||||
isDisabled,
|
|
||||||
validValues,
|
|
||||||
size = 'sm',
|
|
||||||
fontSize = 'md',
|
|
||||||
marginBottom = 1,
|
|
||||||
whiteSpace = 'nowrap',
|
|
||||||
...rest
|
|
||||||
} = props;
|
|
||||||
return (
|
|
||||||
<FormControl isDisabled={isDisabled}>
|
|
||||||
<Flex justifyContent={'space-between'} alignItems={'center'}>
|
|
||||||
<FormLabel
|
|
||||||
marginBottom={marginBottom}
|
|
||||||
>
|
|
||||||
<Text fontSize={fontSize} whiteSpace={whiteSpace}>
|
|
||||||
{label}
|
|
||||||
</Text>
|
|
||||||
</FormLabel>
|
|
||||||
<Select fontSize={fontSize} size={size} {...rest}>
|
|
||||||
{validValues.map((opt) => {
|
|
||||||
return typeof opt === 'string' ||
|
|
||||||
typeof opt === 'number' ? (
|
|
||||||
<option key={opt} value={opt}>
|
|
||||||
{opt}
|
|
||||||
</option>
|
|
||||||
) : (
|
|
||||||
<option key={opt.value} value={opt.value}>
|
|
||||||
{opt.key}
|
|
||||||
</option>
|
|
||||||
);
|
|
||||||
})}
|
|
||||||
</Select>
|
|
||||||
</Flex>
|
|
||||||
</FormControl>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
export default SDSelect;
|
|
@ -1,161 +0,0 @@
|
|||||||
import { Center, Flex, Image, useColorModeValue } from '@chakra-ui/react';
|
|
||||||
import { useAppDispatch, useAppSelector } from '../../app/hooks';
|
|
||||||
import { RootState } from '../../app/store';
|
|
||||||
import { setAllParameters, setInitialImagePath, setSeed } from '../sd/sdSlice';
|
|
||||||
import { useState } from 'react';
|
|
||||||
import ImageMetadataViewer from './ImageMetadataViewer';
|
|
||||||
import DeleteImageModalButton from './DeleteImageModalButton';
|
|
||||||
import SDButton from '../../components/SDButton';
|
|
||||||
import { runESRGAN, runGFPGAN } from '../../app/socketio';
|
|
||||||
import { createSelector } from '@reduxjs/toolkit';
|
|
||||||
import { SystemState } from '../system/systemSlice';
|
|
||||||
import { isEqual } from 'lodash';
|
|
||||||
|
|
||||||
const height = 'calc(100vh - 238px)';
|
|
||||||
|
|
||||||
const systemSelector = createSelector(
|
|
||||||
(state: RootState) => state.system,
|
|
||||||
(system: SystemState) => {
|
|
||||||
return {
|
|
||||||
isProcessing: system.isProcessing,
|
|
||||||
isConnected: system.isConnected,
|
|
||||||
isGFPGANAvailable: system.isGFPGANAvailable,
|
|
||||||
isESRGANAvailable: system.isESRGANAvailable,
|
|
||||||
};
|
|
||||||
},
|
|
||||||
{
|
|
||||||
memoizeOptions: {
|
|
||||||
resultEqualityCheck: isEqual,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
);
|
|
||||||
|
|
||||||
const CurrentImage = () => {
|
|
||||||
const { currentImage, intermediateImage } = useAppSelector(
|
|
||||||
(state: RootState) => state.gallery
|
|
||||||
);
|
|
||||||
const { isProcessing, isConnected, isGFPGANAvailable, isESRGANAvailable } =
|
|
||||||
useAppSelector(systemSelector);
|
|
||||||
|
|
||||||
const dispatch = useAppDispatch();
|
|
||||||
|
|
||||||
const bgColor = useColorModeValue(
|
|
||||||
'rgba(255, 255, 255, 0.85)',
|
|
||||||
'rgba(0, 0, 0, 0.8)'
|
|
||||||
);
|
|
||||||
|
|
||||||
const [shouldShowImageDetails, setShouldShowImageDetails] =
|
|
||||||
useState<boolean>(false);
|
|
||||||
|
|
||||||
const imageToDisplay = intermediateImage || currentImage;
|
|
||||||
|
|
||||||
return (
|
|
||||||
<Flex direction={'column'} rounded={'md'} borderWidth={1} p={2} gap={2}>
|
|
||||||
{imageToDisplay && (
|
|
||||||
<Flex gap={2}>
|
|
||||||
<SDButton
|
|
||||||
label='Use as initial image'
|
|
||||||
colorScheme={'gray'}
|
|
||||||
flexGrow={1}
|
|
||||||
variant={'outline'}
|
|
||||||
onClick={() =>
|
|
||||||
dispatch(setInitialImagePath(imageToDisplay.url))
|
|
||||||
}
|
|
||||||
/>
|
|
||||||
|
|
||||||
<SDButton
|
|
||||||
label='Use all'
|
|
||||||
colorScheme={'gray'}
|
|
||||||
flexGrow={1}
|
|
||||||
variant={'outline'}
|
|
||||||
onClick={() =>
|
|
||||||
dispatch(setAllParameters(imageToDisplay.metadata))
|
|
||||||
}
|
|
||||||
/>
|
|
||||||
|
|
||||||
<SDButton
|
|
||||||
label='Use seed'
|
|
||||||
colorScheme={'gray'}
|
|
||||||
flexGrow={1}
|
|
||||||
variant={'outline'}
|
|
||||||
isDisabled={!imageToDisplay.metadata.seed}
|
|
||||||
onClick={() =>
|
|
||||||
dispatch(setSeed(imageToDisplay.metadata.seed!))
|
|
||||||
}
|
|
||||||
/>
|
|
||||||
|
|
||||||
<SDButton
|
|
||||||
label='Upscale'
|
|
||||||
colorScheme={'gray'}
|
|
||||||
flexGrow={1}
|
|
||||||
variant={'outline'}
|
|
||||||
isDisabled={
|
|
||||||
!isESRGANAvailable ||
|
|
||||||
Boolean(intermediateImage) ||
|
|
||||||
!(isConnected && !isProcessing)
|
|
||||||
}
|
|
||||||
onClick={() => dispatch(runESRGAN(imageToDisplay))}
|
|
||||||
/>
|
|
||||||
<SDButton
|
|
||||||
label='Fix faces'
|
|
||||||
colorScheme={'gray'}
|
|
||||||
flexGrow={1}
|
|
||||||
variant={'outline'}
|
|
||||||
isDisabled={
|
|
||||||
!isGFPGANAvailable ||
|
|
||||||
Boolean(intermediateImage) ||
|
|
||||||
!(isConnected && !isProcessing)
|
|
||||||
}
|
|
||||||
onClick={() => dispatch(runGFPGAN(imageToDisplay))}
|
|
||||||
/>
|
|
||||||
<SDButton
|
|
||||||
label='Details'
|
|
||||||
colorScheme={'gray'}
|
|
||||||
variant={shouldShowImageDetails ? 'solid' : 'outline'}
|
|
||||||
borderWidth={1}
|
|
||||||
flexGrow={1}
|
|
||||||
onClick={() =>
|
|
||||||
setShouldShowImageDetails(!shouldShowImageDetails)
|
|
||||||
}
|
|
||||||
/>
|
|
||||||
<DeleteImageModalButton image={imageToDisplay}>
|
|
||||||
<SDButton
|
|
||||||
label='Delete'
|
|
||||||
colorScheme={'red'}
|
|
||||||
flexGrow={1}
|
|
||||||
variant={'outline'}
|
|
||||||
isDisabled={Boolean(intermediateImage)}
|
|
||||||
/>
|
|
||||||
</DeleteImageModalButton>
|
|
||||||
</Flex>
|
|
||||||
)}
|
|
||||||
<Center height={height} position={'relative'}>
|
|
||||||
{imageToDisplay && (
|
|
||||||
<Image
|
|
||||||
src={imageToDisplay.url}
|
|
||||||
fit='contain'
|
|
||||||
maxWidth={'100%'}
|
|
||||||
maxHeight={'100%'}
|
|
||||||
/>
|
|
||||||
)}
|
|
||||||
{imageToDisplay && shouldShowImageDetails && (
|
|
||||||
<Flex
|
|
||||||
width={'100%'}
|
|
||||||
height={'100%'}
|
|
||||||
position={'absolute'}
|
|
||||||
top={0}
|
|
||||||
left={0}
|
|
||||||
p={3}
|
|
||||||
boxSizing='border-box'
|
|
||||||
backgroundColor={bgColor}
|
|
||||||
overflow='scroll'
|
|
||||||
>
|
|
||||||
<ImageMetadataViewer image={imageToDisplay} />
|
|
||||||
</Flex>
|
|
||||||
)}
|
|
||||||
</Center>
|
|
||||||
</Flex>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
export default CurrentImage;
|
|
155
frontend/src/features/gallery/CurrentImageButtons.tsx
Normal file
155
frontend/src/features/gallery/CurrentImageButtons.tsx
Normal file
@ -0,0 +1,155 @@
|
|||||||
|
import { Flex } from '@chakra-ui/react';
|
||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { isEqual } from 'lodash';
|
||||||
|
|
||||||
|
import * as InvokeAI from '../../app/invokeai';
|
||||||
|
|
||||||
|
import { useAppDispatch, useAppSelector } from '../../app/store';
|
||||||
|
import { RootState } from '../../app/store';
|
||||||
|
import {
|
||||||
|
setAllParameters,
|
||||||
|
setInitialImagePath,
|
||||||
|
setSeed,
|
||||||
|
} from '../options/optionsSlice';
|
||||||
|
import DeleteImageModal from './DeleteImageModal';
|
||||||
|
import { SystemState } from '../system/systemSlice';
|
||||||
|
import SDButton from '../../common/components/SDButton';
|
||||||
|
import { runESRGAN, runGFPGAN } from '../../app/socketio/actions';
|
||||||
|
|
||||||
|
const systemSelector = createSelector(
|
||||||
|
(state: RootState) => state.system,
|
||||||
|
(system: SystemState) => {
|
||||||
|
return {
|
||||||
|
isProcessing: system.isProcessing,
|
||||||
|
isConnected: system.isConnected,
|
||||||
|
isGFPGANAvailable: system.isGFPGANAvailable,
|
||||||
|
isESRGANAvailable: system.isESRGANAvailable,
|
||||||
|
};
|
||||||
|
},
|
||||||
|
{
|
||||||
|
memoizeOptions: {
|
||||||
|
resultEqualityCheck: isEqual,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
type CurrentImageButtonsProps = {
|
||||||
|
image: InvokeAI.Image;
|
||||||
|
shouldShowImageDetails: boolean;
|
||||||
|
setShouldShowImageDetails: (b: boolean) => void;
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Row of buttons for common actions:
|
||||||
|
* Use as init image, use all params, use seed, upscale, fix faces, details, delete.
|
||||||
|
*/
|
||||||
|
const CurrentImageButtons = ({
|
||||||
|
image,
|
||||||
|
shouldShowImageDetails,
|
||||||
|
setShouldShowImageDetails,
|
||||||
|
}: CurrentImageButtonsProps) => {
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
|
const { intermediateImage } = useAppSelector(
|
||||||
|
(state: RootState) => state.gallery
|
||||||
|
);
|
||||||
|
|
||||||
|
const { upscalingLevel, gfpganStrength } = useAppSelector(
|
||||||
|
(state: RootState) => state.options
|
||||||
|
);
|
||||||
|
|
||||||
|
const { isProcessing, isConnected, isGFPGANAvailable, isESRGANAvailable } =
|
||||||
|
useAppSelector(systemSelector);
|
||||||
|
|
||||||
|
const handleClickUseAsInitialImage = () =>
|
||||||
|
dispatch(setInitialImagePath(image.url));
|
||||||
|
|
||||||
|
const handleClickUseAllParameters = () =>
|
||||||
|
dispatch(setAllParameters(image.metadata));
|
||||||
|
|
||||||
|
// Non-null assertion: this button is disabled if there is no seed.
|
||||||
|
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
|
||||||
|
const handleClickUseSeed = () => dispatch(setSeed(image.metadata.image.seed));
|
||||||
|
const handleClickUpscale = () => dispatch(runESRGAN(image));
|
||||||
|
|
||||||
|
const handleClickFixFaces = () => dispatch(runGFPGAN(image));
|
||||||
|
|
||||||
|
const handleClickShowImageDetails = () =>
|
||||||
|
setShouldShowImageDetails(!shouldShowImageDetails);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Flex gap={2}>
|
||||||
|
<SDButton
|
||||||
|
label="Use as initial image"
|
||||||
|
colorScheme={'gray'}
|
||||||
|
flexGrow={1}
|
||||||
|
variant={'outline'}
|
||||||
|
onClick={handleClickUseAsInitialImage}
|
||||||
|
/>
|
||||||
|
|
||||||
|
<SDButton
|
||||||
|
label="Use all"
|
||||||
|
colorScheme={'gray'}
|
||||||
|
flexGrow={1}
|
||||||
|
variant={'outline'}
|
||||||
|
isDisabled={!['txt2img', 'img2img'].includes(image.metadata.image.type)}
|
||||||
|
onClick={handleClickUseAllParameters}
|
||||||
|
/>
|
||||||
|
|
||||||
|
<SDButton
|
||||||
|
label="Use seed"
|
||||||
|
colorScheme={'gray'}
|
||||||
|
flexGrow={1}
|
||||||
|
variant={'outline'}
|
||||||
|
isDisabled={!image.metadata.image.seed}
|
||||||
|
onClick={handleClickUseSeed}
|
||||||
|
/>
|
||||||
|
|
||||||
|
<SDButton
|
||||||
|
label="Upscale"
|
||||||
|
colorScheme={'gray'}
|
||||||
|
flexGrow={1}
|
||||||
|
variant={'outline'}
|
||||||
|
isDisabled={
|
||||||
|
!isESRGANAvailable ||
|
||||||
|
Boolean(intermediateImage) ||
|
||||||
|
!(isConnected && !isProcessing) ||
|
||||||
|
!upscalingLevel
|
||||||
|
}
|
||||||
|
onClick={handleClickUpscale}
|
||||||
|
/>
|
||||||
|
<SDButton
|
||||||
|
label="Fix faces"
|
||||||
|
colorScheme={'gray'}
|
||||||
|
flexGrow={1}
|
||||||
|
variant={'outline'}
|
||||||
|
isDisabled={
|
||||||
|
!isGFPGANAvailable ||
|
||||||
|
Boolean(intermediateImage) ||
|
||||||
|
!(isConnected && !isProcessing) ||
|
||||||
|
!gfpganStrength
|
||||||
|
}
|
||||||
|
onClick={handleClickFixFaces}
|
||||||
|
/>
|
||||||
|
<SDButton
|
||||||
|
label="Details"
|
||||||
|
colorScheme={'gray'}
|
||||||
|
variant={shouldShowImageDetails ? 'solid' : 'outline'}
|
||||||
|
borderWidth={1}
|
||||||
|
flexGrow={1}
|
||||||
|
onClick={handleClickShowImageDetails}
|
||||||
|
/>
|
||||||
|
<DeleteImageModal image={image}>
|
||||||
|
<SDButton
|
||||||
|
label="Delete"
|
||||||
|
colorScheme={'red'}
|
||||||
|
flexGrow={1}
|
||||||
|
variant={'outline'}
|
||||||
|
isDisabled={Boolean(intermediateImage)}
|
||||||
|
/>
|
||||||
|
</DeleteImageModal>
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default CurrentImageButtons;
|
67
frontend/src/features/gallery/CurrentImageDisplay.tsx
Normal file
67
frontend/src/features/gallery/CurrentImageDisplay.tsx
Normal file
@ -0,0 +1,67 @@
|
|||||||
|
import { Center, Flex, Image, Text, useColorModeValue } from '@chakra-ui/react';
|
||||||
|
import { useAppSelector } from '../../app/store';
|
||||||
|
import { RootState } from '../../app/store';
|
||||||
|
import { useState } from 'react';
|
||||||
|
import ImageMetadataViewer from './ImageMetadataViewer';
|
||||||
|
import CurrentImageButtons from './CurrentImageButtons';
|
||||||
|
|
||||||
|
// TODO: With CSS Grid I had a hard time centering the image in a grid item. This is needed for that.
|
||||||
|
const height = 'calc(100vh - 238px)';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Displays the current image if there is one, plus associated actions.
|
||||||
|
*/
|
||||||
|
const CurrentImageDisplay = () => {
|
||||||
|
const { currentImage, intermediateImage } = useAppSelector(
|
||||||
|
(state: RootState) => state.gallery
|
||||||
|
);
|
||||||
|
|
||||||
|
const bgColor = useColorModeValue(
|
||||||
|
'rgba(255, 255, 255, 0.85)',
|
||||||
|
'rgba(0, 0, 0, 0.8)'
|
||||||
|
);
|
||||||
|
|
||||||
|
const [shouldShowImageDetails, setShouldShowImageDetails] =
|
||||||
|
useState<boolean>(false);
|
||||||
|
|
||||||
|
const imageToDisplay = intermediateImage || currentImage;
|
||||||
|
|
||||||
|
return imageToDisplay ? (
|
||||||
|
<Flex direction={'column'} borderWidth={1} rounded={'md'} p={2} gap={2}>
|
||||||
|
<CurrentImageButtons
|
||||||
|
image={imageToDisplay}
|
||||||
|
shouldShowImageDetails={shouldShowImageDetails}
|
||||||
|
setShouldShowImageDetails={setShouldShowImageDetails}
|
||||||
|
/>
|
||||||
|
<Center height={height} position={'relative'}>
|
||||||
|
<Image
|
||||||
|
src={imageToDisplay.url}
|
||||||
|
fit="contain"
|
||||||
|
maxWidth={'100%'}
|
||||||
|
maxHeight={'100%'}
|
||||||
|
/>
|
||||||
|
{shouldShowImageDetails && (
|
||||||
|
<Flex
|
||||||
|
width={'100%'}
|
||||||
|
height={'100%'}
|
||||||
|
position={'absolute'}
|
||||||
|
top={0}
|
||||||
|
left={0}
|
||||||
|
p={3}
|
||||||
|
boxSizing="border-box"
|
||||||
|
backgroundColor={bgColor}
|
||||||
|
overflow="scroll"
|
||||||
|
>
|
||||||
|
<ImageMetadataViewer image={imageToDisplay} />
|
||||||
|
</Flex>
|
||||||
|
)}
|
||||||
|
</Center>
|
||||||
|
</Flex>
|
||||||
|
) : (
|
||||||
|
<Center height={'100%'} position={'relative'}>
|
||||||
|
<Text size={'xl'}>No image selected</Text>
|
||||||
|
</Center>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default CurrentImageDisplay;
|
125
frontend/src/features/gallery/DeleteImageModal.tsx
Normal file
125
frontend/src/features/gallery/DeleteImageModal.tsx
Normal file
@ -0,0 +1,125 @@
|
|||||||
|
import {
|
||||||
|
Text,
|
||||||
|
AlertDialog,
|
||||||
|
AlertDialogBody,
|
||||||
|
AlertDialogFooter,
|
||||||
|
AlertDialogHeader,
|
||||||
|
AlertDialogContent,
|
||||||
|
AlertDialogOverlay,
|
||||||
|
useDisclosure,
|
||||||
|
Button,
|
||||||
|
Switch,
|
||||||
|
FormControl,
|
||||||
|
FormLabel,
|
||||||
|
Flex,
|
||||||
|
} from '@chakra-ui/react';
|
||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import {
|
||||||
|
ChangeEvent,
|
||||||
|
cloneElement,
|
||||||
|
forwardRef,
|
||||||
|
ReactElement,
|
||||||
|
SyntheticEvent,
|
||||||
|
useRef,
|
||||||
|
} from 'react';
|
||||||
|
import { useAppDispatch, useAppSelector } from '../../app/store';
|
||||||
|
import { deleteImage } from '../../app/socketio/actions';
|
||||||
|
import { RootState } from '../../app/store';
|
||||||
|
import { setShouldConfirmOnDelete, SystemState } from '../system/systemSlice';
|
||||||
|
import * as InvokeAI from '../../app/invokeai';
|
||||||
|
|
||||||
|
interface DeleteImageModalProps {
|
||||||
|
/**
|
||||||
|
* Component which, on click, should delete the image/open the modal.
|
||||||
|
*/
|
||||||
|
children: ReactElement;
|
||||||
|
/**
|
||||||
|
* The image to delete.
|
||||||
|
*/
|
||||||
|
image: InvokeAI.Image;
|
||||||
|
}
|
||||||
|
|
||||||
|
const systemSelector = createSelector(
|
||||||
|
(state: RootState) => state.system,
|
||||||
|
(system: SystemState) => system.shouldConfirmOnDelete
|
||||||
|
);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Needs a child, which will act as the button to delete an image.
|
||||||
|
* If system.shouldConfirmOnDelete is true, a confirmation modal is displayed.
|
||||||
|
* If it is false, the image is deleted immediately.
|
||||||
|
* The confirmation modal has a "Don't ask me again" switch to set the boolean.
|
||||||
|
*/
|
||||||
|
const DeleteImageModal = forwardRef(
|
||||||
|
({ image, children }: DeleteImageModalProps, ref) => {
|
||||||
|
const { isOpen, onOpen, onClose } = useDisclosure();
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const shouldConfirmOnDelete = useAppSelector(systemSelector);
|
||||||
|
const cancelRef = useRef<HTMLButtonElement>(null);
|
||||||
|
|
||||||
|
const handleClickDelete = (e: SyntheticEvent) => {
|
||||||
|
e.stopPropagation();
|
||||||
|
shouldConfirmOnDelete ? onOpen() : handleDelete();
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleDelete = () => {
|
||||||
|
dispatch(deleteImage(image));
|
||||||
|
onClose();
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleChangeShouldConfirmOnDelete = (
|
||||||
|
e: ChangeEvent<HTMLInputElement>
|
||||||
|
) => dispatch(setShouldConfirmOnDelete(!e.target.checked));
|
||||||
|
|
||||||
|
return (
|
||||||
|
<>
|
||||||
|
{cloneElement(children, {
|
||||||
|
// TODO: This feels wrong.
|
||||||
|
onClick: handleClickDelete,
|
||||||
|
ref: ref,
|
||||||
|
})}
|
||||||
|
|
||||||
|
<AlertDialog
|
||||||
|
isOpen={isOpen}
|
||||||
|
leastDestructiveRef={cancelRef}
|
||||||
|
onClose={onClose}
|
||||||
|
>
|
||||||
|
<AlertDialogOverlay>
|
||||||
|
<AlertDialogContent>
|
||||||
|
<AlertDialogHeader fontSize="lg" fontWeight="bold">
|
||||||
|
Delete image
|
||||||
|
</AlertDialogHeader>
|
||||||
|
|
||||||
|
<AlertDialogBody>
|
||||||
|
<Flex direction={'column'} gap={5}>
|
||||||
|
<Text>
|
||||||
|
Are you sure? You can't undo this action afterwards.
|
||||||
|
</Text>
|
||||||
|
<FormControl>
|
||||||
|
<Flex alignItems={'center'}>
|
||||||
|
<FormLabel mb={0}>Don't ask me again</FormLabel>
|
||||||
|
<Switch
|
||||||
|
checked={!shouldConfirmOnDelete}
|
||||||
|
onChange={handleChangeShouldConfirmOnDelete}
|
||||||
|
/>
|
||||||
|
</Flex>
|
||||||
|
</FormControl>
|
||||||
|
</Flex>
|
||||||
|
</AlertDialogBody>
|
||||||
|
<AlertDialogFooter>
|
||||||
|
<Button ref={cancelRef} onClick={onClose}>
|
||||||
|
Cancel
|
||||||
|
</Button>
|
||||||
|
<Button colorScheme="red" onClick={handleDelete} ml={3}>
|
||||||
|
Delete
|
||||||
|
</Button>
|
||||||
|
</AlertDialogFooter>
|
||||||
|
</AlertDialogContent>
|
||||||
|
</AlertDialogOverlay>
|
||||||
|
</AlertDialog>
|
||||||
|
</>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
export default DeleteImageModal;
|
@ -1,94 +0,0 @@
|
|||||||
import {
|
|
||||||
IconButtonProps,
|
|
||||||
Modal,
|
|
||||||
ModalBody,
|
|
||||||
ModalCloseButton,
|
|
||||||
ModalContent,
|
|
||||||
ModalFooter,
|
|
||||||
ModalHeader,
|
|
||||||
ModalOverlay,
|
|
||||||
Text,
|
|
||||||
useDisclosure,
|
|
||||||
} from '@chakra-ui/react';
|
|
||||||
import { createSelector } from '@reduxjs/toolkit';
|
|
||||||
import {
|
|
||||||
cloneElement,
|
|
||||||
ReactElement,
|
|
||||||
SyntheticEvent,
|
|
||||||
} from 'react';
|
|
||||||
import { useAppDispatch, useAppSelector } from '../../app/hooks';
|
|
||||||
import { deleteImage } from '../../app/socketio';
|
|
||||||
import { RootState } from '../../app/store';
|
|
||||||
import SDButton from '../../components/SDButton';
|
|
||||||
import { setShouldConfirmOnDelete, SystemState } from '../system/systemSlice';
|
|
||||||
import { SDImage } from './gallerySlice';
|
|
||||||
|
|
||||||
interface Props extends IconButtonProps {
|
|
||||||
image: SDImage;
|
|
||||||
'aria-label': string;
|
|
||||||
children: ReactElement;
|
|
||||||
}
|
|
||||||
|
|
||||||
const systemSelector = createSelector(
|
|
||||||
(state: RootState) => state.system,
|
|
||||||
(system: SystemState) => system.shouldConfirmOnDelete
|
|
||||||
);
|
|
||||||
|
|
||||||
/*
|
|
||||||
TODO: The modal and button to open it should be two different components,
|
|
||||||
but their state is closely related and I'm not sure how best to accomplish it.
|
|
||||||
*/
|
|
||||||
const DeleteImageModalButton = (props: Omit<Props, 'aria-label'>) => {
|
|
||||||
const { isOpen, onOpen, onClose } = useDisclosure();
|
|
||||||
const dispatch = useAppDispatch();
|
|
||||||
const shouldConfirmOnDelete = useAppSelector(systemSelector);
|
|
||||||
|
|
||||||
const handleClickDelete = (e: SyntheticEvent) => {
|
|
||||||
e.stopPropagation();
|
|
||||||
shouldConfirmOnDelete ? onOpen() : handleDelete();
|
|
||||||
};
|
|
||||||
|
|
||||||
const { image, children } = props;
|
|
||||||
|
|
||||||
const handleDelete = () => {
|
|
||||||
dispatch(deleteImage(image));
|
|
||||||
onClose();
|
|
||||||
};
|
|
||||||
|
|
||||||
const handleDeleteAndDontAsk = () => {
|
|
||||||
dispatch(deleteImage(image));
|
|
||||||
dispatch(setShouldConfirmOnDelete(false));
|
|
||||||
onClose();
|
|
||||||
};
|
|
||||||
|
|
||||||
return (
|
|
||||||
<>
|
|
||||||
{cloneElement(children, {
|
|
||||||
onClick: handleClickDelete,
|
|
||||||
})}
|
|
||||||
|
|
||||||
<Modal isOpen={isOpen} onClose={onClose}>
|
|
||||||
<ModalOverlay />
|
|
||||||
<ModalContent>
|
|
||||||
<ModalHeader>Are you sure you want to delete this image?</ModalHeader>
|
|
||||||
<ModalCloseButton />
|
|
||||||
<ModalBody>
|
|
||||||
<Text>It will be deleted forever!</Text>
|
|
||||||
</ModalBody>
|
|
||||||
|
|
||||||
<ModalFooter justifyContent={'space-between'}>
|
|
||||||
<SDButton label={'Yes'} colorScheme='red' onClick={handleDelete} />
|
|
||||||
<SDButton
|
|
||||||
label={"Yes, and don't ask me again"}
|
|
||||||
colorScheme='red'
|
|
||||||
onClick={handleDeleteAndDontAsk}
|
|
||||||
/>
|
|
||||||
<SDButton label='Cancel' colorScheme='blue' onClick={onClose} />
|
|
||||||
</ModalFooter>
|
|
||||||
</ModalContent>
|
|
||||||
</Modal>
|
|
||||||
</>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
export default DeleteImageModalButton;
|
|
143
frontend/src/features/gallery/HoverableImage.tsx
Normal file
143
frontend/src/features/gallery/HoverableImage.tsx
Normal file
@ -0,0 +1,143 @@
|
|||||||
|
import {
|
||||||
|
Box,
|
||||||
|
Flex,
|
||||||
|
Icon,
|
||||||
|
IconButton,
|
||||||
|
Image,
|
||||||
|
Tooltip,
|
||||||
|
useColorModeValue,
|
||||||
|
} from '@chakra-ui/react';
|
||||||
|
import { useAppDispatch } from '../../app/store';
|
||||||
|
import { setCurrentImage } from './gallerySlice';
|
||||||
|
import { FaCheck, FaSeedling, FaTrashAlt } from 'react-icons/fa';
|
||||||
|
import DeleteImageModal from './DeleteImageModal';
|
||||||
|
import { memo, SyntheticEvent, useState } from 'react';
|
||||||
|
import { setAllParameters, setSeed } from '../options/optionsSlice';
|
||||||
|
import * as InvokeAI from '../../app/invokeai';
|
||||||
|
import { IoArrowUndoCircleOutline } from 'react-icons/io5';
|
||||||
|
|
||||||
|
interface HoverableImageProps {
|
||||||
|
image: InvokeAI.Image;
|
||||||
|
isSelected: boolean;
|
||||||
|
}
|
||||||
|
|
||||||
|
const memoEqualityCheck = (
|
||||||
|
prev: HoverableImageProps,
|
||||||
|
next: HoverableImageProps
|
||||||
|
) => prev.image.uuid === next.image.uuid && prev.isSelected === next.isSelected;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Gallery image component with delete/use all/use seed buttons on hover.
|
||||||
|
*/
|
||||||
|
const HoverableImage = memo((props: HoverableImageProps) => {
|
||||||
|
const [isHovered, setIsHovered] = useState<boolean>(false);
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
|
const checkColor = useColorModeValue('green.600', 'green.300');
|
||||||
|
const bgColor = useColorModeValue('gray.200', 'gray.700');
|
||||||
|
const bgGradient = useColorModeValue(
|
||||||
|
'radial-gradient(circle, rgba(255,255,255,0.7) 0%, rgba(255,255,255,0.7) 20%, rgba(0,0,0,0) 100%)',
|
||||||
|
'radial-gradient(circle, rgba(0,0,0,0.7) 0%, rgba(0,0,0,0.7) 20%, rgba(0,0,0,0) 100%)'
|
||||||
|
);
|
||||||
|
|
||||||
|
const { image, isSelected } = props;
|
||||||
|
const { url, uuid, metadata } = image;
|
||||||
|
|
||||||
|
const handleMouseOver = () => setIsHovered(true);
|
||||||
|
const handleMouseOut = () => setIsHovered(false);
|
||||||
|
|
||||||
|
const handleClickSetAllParameters = (e: SyntheticEvent) => {
|
||||||
|
e.stopPropagation();
|
||||||
|
dispatch(setAllParameters(metadata));
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleClickSetSeed = (e: SyntheticEvent) => {
|
||||||
|
e.stopPropagation();
|
||||||
|
// Non-null assertion: this button is not rendered unless this exists
|
||||||
|
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
|
||||||
|
dispatch(setSeed(image.metadata.image.seed));
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleClickImage = () => dispatch(setCurrentImage(image));
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Box position={'relative'} key={uuid}>
|
||||||
|
<Image
|
||||||
|
width={120}
|
||||||
|
height={120}
|
||||||
|
objectFit="cover"
|
||||||
|
rounded={'md'}
|
||||||
|
src={url}
|
||||||
|
loading={'lazy'}
|
||||||
|
backgroundColor={bgColor}
|
||||||
|
/>
|
||||||
|
<Flex
|
||||||
|
cursor={'pointer'}
|
||||||
|
position={'absolute'}
|
||||||
|
top={0}
|
||||||
|
left={0}
|
||||||
|
rounded={'md'}
|
||||||
|
width="100%"
|
||||||
|
height="100%"
|
||||||
|
alignItems={'center'}
|
||||||
|
justifyContent={'center'}
|
||||||
|
background={isSelected ? bgGradient : undefined}
|
||||||
|
onClick={handleClickImage}
|
||||||
|
onMouseOver={handleMouseOver}
|
||||||
|
onMouseOut={handleMouseOut}
|
||||||
|
>
|
||||||
|
{isSelected && (
|
||||||
|
<Icon fill={checkColor} width={'50%'} height={'50%'} as={FaCheck} />
|
||||||
|
)}
|
||||||
|
{isHovered && (
|
||||||
|
<Flex
|
||||||
|
direction={'column'}
|
||||||
|
gap={1}
|
||||||
|
position={'absolute'}
|
||||||
|
top={1}
|
||||||
|
right={1}
|
||||||
|
>
|
||||||
|
<Tooltip label={'Delete image'}>
|
||||||
|
<DeleteImageModal image={image}>
|
||||||
|
<IconButton
|
||||||
|
colorScheme="red"
|
||||||
|
aria-label="Delete image"
|
||||||
|
icon={<FaTrashAlt />}
|
||||||
|
size="xs"
|
||||||
|
variant={'imageHoverIconButton'}
|
||||||
|
fontSize={14}
|
||||||
|
/>
|
||||||
|
</DeleteImageModal>
|
||||||
|
</Tooltip>
|
||||||
|
{['txt2img', 'img2img'].includes(image.metadata.image.type) && (
|
||||||
|
<Tooltip label="Use all parameters">
|
||||||
|
<IconButton
|
||||||
|
aria-label="Use all parameters"
|
||||||
|
icon={<IoArrowUndoCircleOutline />}
|
||||||
|
size="xs"
|
||||||
|
fontSize={18}
|
||||||
|
variant={'imageHoverIconButton'}
|
||||||
|
onClickCapture={handleClickSetAllParameters}
|
||||||
|
/>
|
||||||
|
</Tooltip>
|
||||||
|
)}
|
||||||
|
{image.metadata.image.seed && (
|
||||||
|
<Tooltip label="Use seed">
|
||||||
|
<IconButton
|
||||||
|
aria-label="Use seed"
|
||||||
|
icon={<FaSeedling />}
|
||||||
|
size="xs"
|
||||||
|
fontSize={16}
|
||||||
|
variant={'imageHoverIconButton'}
|
||||||
|
onClickCapture={handleClickSetSeed}
|
||||||
|
/>
|
||||||
|
</Tooltip>
|
||||||
|
)}
|
||||||
|
</Flex>
|
||||||
|
)}
|
||||||
|
</Flex>
|
||||||
|
</Box>
|
||||||
|
);
|
||||||
|
}, memoEqualityCheck);
|
||||||
|
|
||||||
|
export default HoverableImage;
|
39
frontend/src/features/gallery/ImageGallery.tsx
Normal file
39
frontend/src/features/gallery/ImageGallery.tsx
Normal file
@ -0,0 +1,39 @@
|
|||||||
|
import { Center, Flex, Text } from '@chakra-ui/react';
|
||||||
|
import { RootState } from '../../app/store';
|
||||||
|
import { useAppSelector } from '../../app/store';
|
||||||
|
import HoverableImage from './HoverableImage';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Simple image gallery.
|
||||||
|
*/
|
||||||
|
const ImageGallery = () => {
|
||||||
|
const { images, currentImageUuid } = useAppSelector(
|
||||||
|
(state: RootState) => state.gallery
|
||||||
|
);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* I don't like that this needs to rerender whenever the current image is changed.
|
||||||
|
* What if we have a large number of images? I suppose pagination (planned) will
|
||||||
|
* mitigate this issue.
|
||||||
|
*
|
||||||
|
* TODO: Refactor if performance complaints, or after migrating to new API which supports pagination.
|
||||||
|
*/
|
||||||
|
|
||||||
|
return images.length ? (
|
||||||
|
<Flex gap={2} wrap="wrap" pb={2}>
|
||||||
|
{[...images].reverse().map((image) => {
|
||||||
|
const { uuid } = image;
|
||||||
|
const isSelected = currentImageUuid === uuid;
|
||||||
|
return (
|
||||||
|
<HoverableImage key={uuid} image={image} isSelected={isSelected} />
|
||||||
|
);
|
||||||
|
})}
|
||||||
|
</Flex>
|
||||||
|
) : (
|
||||||
|
<Center height={'100%'} position={'relative'}>
|
||||||
|
<Text size={'xl'}>No images in gallery</Text>
|
||||||
|
</Center>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default ImageGallery;
|
@ -1,124 +1,326 @@
|
|||||||
import {
|
import {
|
||||||
Center,
|
Box,
|
||||||
Flex,
|
Center,
|
||||||
IconButton,
|
Flex,
|
||||||
Link,
|
IconButton,
|
||||||
List,
|
Link,
|
||||||
ListItem,
|
Text,
|
||||||
Text,
|
Tooltip,
|
||||||
|
useColorModeValue,
|
||||||
} from '@chakra-ui/react';
|
} from '@chakra-ui/react';
|
||||||
import { FaPlus } from 'react-icons/fa';
|
import { ExternalLinkIcon } from '@chakra-ui/icons';
|
||||||
import { PARAMETERS } from '../../app/constants';
|
import { memo } from 'react';
|
||||||
import { useAppDispatch } from '../../app/hooks';
|
import { IoArrowUndoCircleOutline } from 'react-icons/io5';
|
||||||
import SDButton from '../../components/SDButton';
|
import { useAppDispatch } from '../../app/store';
|
||||||
import { setAllParameters, setParameter } from '../sd/sdSlice';
|
import * as InvokeAI from '../../app/invokeai';
|
||||||
import { SDImage, SDMetadata } from './gallerySlice';
|
import {
|
||||||
|
setCfgScale,
|
||||||
|
setGfpganStrength,
|
||||||
|
setHeight,
|
||||||
|
setImg2imgStrength,
|
||||||
|
setInitialImagePath,
|
||||||
|
setMaskPath,
|
||||||
|
setPrompt,
|
||||||
|
setSampler,
|
||||||
|
setSeed,
|
||||||
|
setSeedWeights,
|
||||||
|
setShouldFitToWidthHeight,
|
||||||
|
setSteps,
|
||||||
|
setUpscalingLevel,
|
||||||
|
setUpscalingStrength,
|
||||||
|
setWidth,
|
||||||
|
} from '../options/optionsSlice';
|
||||||
|
import promptToString from '../../common/util/promptToString';
|
||||||
|
import { seedWeightsToString } from '../../common/util/seedWeightPairs';
|
||||||
|
import { FaCopy } from 'react-icons/fa';
|
||||||
|
|
||||||
type Props = {
|
type MetadataItemProps = {
|
||||||
image: SDImage;
|
isLink?: boolean;
|
||||||
|
label: string;
|
||||||
|
onClick?: () => void;
|
||||||
|
value: number | string | boolean;
|
||||||
};
|
};
|
||||||
|
|
||||||
const ImageMetadataViewer = ({ image }: Props) => {
|
/**
|
||||||
const dispatch = useAppDispatch();
|
* Component to display an individual metadata item or parameter.
|
||||||
|
*/
|
||||||
|
const MetadataItem = ({ label, value, onClick, isLink }: MetadataItemProps) => {
|
||||||
|
return (
|
||||||
|
<Flex gap={2}>
|
||||||
|
{onClick && (
|
||||||
|
<Tooltip label={`Recall ${label}`}>
|
||||||
|
<IconButton
|
||||||
|
aria-label="Use this parameter"
|
||||||
|
icon={<IoArrowUndoCircleOutline />}
|
||||||
|
size={'xs'}
|
||||||
|
variant={'ghost'}
|
||||||
|
fontSize={20}
|
||||||
|
onClick={onClick}
|
||||||
|
/>
|
||||||
|
</Tooltip>
|
||||||
|
)}
|
||||||
|
<Text fontWeight={'semibold'} whiteSpace={'nowrap'}>
|
||||||
|
{label}:
|
||||||
|
</Text>
|
||||||
|
{isLink ? (
|
||||||
|
<Link href={value.toString()} isExternal wordBreak={'break-all'}>
|
||||||
|
{value.toString()} <ExternalLinkIcon mx="2px" />
|
||||||
|
</Link>
|
||||||
|
) : (
|
||||||
|
<Text maxHeight={100} overflowY={'scroll'} wordBreak={'break-all'}>
|
||||||
|
{value.toString()}
|
||||||
|
</Text>
|
||||||
|
)}
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
const keys = Object.keys(PARAMETERS);
|
type ImageMetadataViewerProps = {
|
||||||
|
image: InvokeAI.Image;
|
||||||
|
};
|
||||||
|
|
||||||
const metadata: Array<{
|
// TODO: I don't know if this is needed.
|
||||||
label: string;
|
const memoEqualityCheck = (
|
||||||
key: string;
|
prev: ImageMetadataViewerProps,
|
||||||
value: string | number | boolean;
|
next: ImageMetadataViewerProps
|
||||||
}> = [];
|
) => prev.image.uuid === next.image.uuid;
|
||||||
|
|
||||||
keys.forEach((key) => {
|
// TODO: Show more interesting information in this component.
|
||||||
const value = image.metadata[key as keyof SDMetadata];
|
|
||||||
if (value !== undefined) {
|
|
||||||
metadata.push({ label: PARAMETERS[key], key, value });
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
return (
|
/**
|
||||||
<Flex gap={2} direction={'column'} overflowY={'scroll'} width={'100%'}>
|
* Image metadata viewer overlays currently selected image and provides
|
||||||
<SDButton
|
* access to any of its metadata for use in processing.
|
||||||
label='Use all parameters'
|
*/
|
||||||
colorScheme={'gray'}
|
const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
|
||||||
padding={2}
|
const dispatch = useAppDispatch();
|
||||||
isDisabled={metadata.length === 0}
|
const jsonBgColor = useColorModeValue('blackAlpha.100', 'whiteAlpha.100');
|
||||||
onClick={() => dispatch(setAllParameters(image.metadata))}
|
|
||||||
|
const metadata = image.metadata.image;
|
||||||
|
const {
|
||||||
|
type,
|
||||||
|
postprocessing,
|
||||||
|
sampler,
|
||||||
|
prompt,
|
||||||
|
seed,
|
||||||
|
variations,
|
||||||
|
steps,
|
||||||
|
cfg_scale,
|
||||||
|
seamless,
|
||||||
|
width,
|
||||||
|
height,
|
||||||
|
strength,
|
||||||
|
fit,
|
||||||
|
init_image_path,
|
||||||
|
mask_image_path,
|
||||||
|
orig_path,
|
||||||
|
scale,
|
||||||
|
} = metadata;
|
||||||
|
|
||||||
|
const metadataJSON = JSON.stringify(metadata, null, 2);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Flex
|
||||||
|
gap={1}
|
||||||
|
direction={'column'}
|
||||||
|
overflowY={'scroll'}
|
||||||
|
width={'100%'}
|
||||||
|
>
|
||||||
|
<Flex gap={2}>
|
||||||
|
<Text fontWeight={'semibold'}>File:</Text>
|
||||||
|
<Link href={image.url} isExternal>
|
||||||
|
{image.url}
|
||||||
|
<ExternalLinkIcon mx="2px" />
|
||||||
|
</Link>
|
||||||
|
</Flex>
|
||||||
|
{Object.keys(metadata).length ? (
|
||||||
|
<>
|
||||||
|
{type && <MetadataItem label="Type" value={type} />}
|
||||||
|
{['esrgan', 'gfpgan'].includes(type) && (
|
||||||
|
<MetadataItem label="Original image" value={orig_path} isLink />
|
||||||
|
)}
|
||||||
|
{type === 'gfpgan' && strength && (
|
||||||
|
<MetadataItem
|
||||||
|
label="Fix faces strength"
|
||||||
|
value={strength}
|
||||||
|
onClick={() => dispatch(setGfpganStrength(strength))}
|
||||||
/>
|
/>
|
||||||
<Flex gap={2}>
|
)}
|
||||||
<Text fontWeight={'semibold'}>File:</Text>
|
{type === 'esrgan' && scale && (
|
||||||
<Link href={image.url} isExternal>
|
<MetadataItem
|
||||||
<Text>{image.url}</Text>
|
label="Upscaling scale"
|
||||||
</Link>
|
value={scale}
|
||||||
</Flex>
|
onClick={() => dispatch(setUpscalingLevel(scale))}
|
||||||
{metadata.length ? (
|
/>
|
||||||
<>
|
)}
|
||||||
<List>
|
{type === 'esrgan' && strength && (
|
||||||
{metadata.map((parameter, i) => {
|
<MetadataItem
|
||||||
const { label, key, value } = parameter;
|
label="Upscaling strength"
|
||||||
return (
|
value={strength}
|
||||||
<ListItem key={i} pb={1}>
|
onClick={() => dispatch(setUpscalingStrength(strength))}
|
||||||
<Flex gap={2}>
|
/>
|
||||||
<IconButton
|
)}
|
||||||
aria-label='Use this parameter'
|
{prompt && (
|
||||||
icon={<FaPlus />}
|
<MetadataItem
|
||||||
size={'xs'}
|
label="Prompt"
|
||||||
onClick={() =>
|
value={promptToString(prompt)}
|
||||||
dispatch(
|
onClick={() => dispatch(setPrompt(prompt))}
|
||||||
setParameter({
|
/>
|
||||||
key,
|
)}
|
||||||
value,
|
{seed && (
|
||||||
})
|
<MetadataItem
|
||||||
)
|
label="Seed"
|
||||||
}
|
value={seed}
|
||||||
/>
|
onClick={() => dispatch(setSeed(seed))}
|
||||||
<Text fontWeight={'semibold'}>
|
/>
|
||||||
{label}:
|
)}
|
||||||
</Text>
|
{sampler && (
|
||||||
|
<MetadataItem
|
||||||
{value === undefined ||
|
label="Sampler"
|
||||||
value === null ||
|
value={sampler}
|
||||||
value === '' ||
|
onClick={() => dispatch(setSampler(sampler))}
|
||||||
value === 0 ? (
|
/>
|
||||||
<Text
|
)}
|
||||||
maxHeight={100}
|
{steps && (
|
||||||
fontStyle={'italic'}
|
<MetadataItem
|
||||||
>
|
label="Steps"
|
||||||
None
|
value={steps}
|
||||||
</Text>
|
onClick={() => dispatch(setSteps(steps))}
|
||||||
) : (
|
/>
|
||||||
<Text
|
)}
|
||||||
maxHeight={100}
|
{cfg_scale && (
|
||||||
overflowY={'scroll'}
|
<MetadataItem
|
||||||
>
|
label="CFG scale"
|
||||||
{value.toString()}
|
value={cfg_scale}
|
||||||
</Text>
|
onClick={() => dispatch(setCfgScale(cfg_scale))}
|
||||||
)}
|
/>
|
||||||
</Flex>
|
)}
|
||||||
</ListItem>
|
{variations && variations.length > 0 && (
|
||||||
);
|
<MetadataItem
|
||||||
})}
|
label="Seed-weight pairs"
|
||||||
</List>
|
value={seedWeightsToString(variations)}
|
||||||
<Flex gap={2}>
|
onClick={() =>
|
||||||
<Text fontWeight={'semibold'}>Raw:</Text>
|
dispatch(setSeedWeights(seedWeightsToString(variations)))
|
||||||
<Text
|
}
|
||||||
maxHeight={100}
|
/>
|
||||||
overflowY={'scroll'}
|
)}
|
||||||
wordBreak={'break-all'}
|
{seamless && (
|
||||||
>
|
<MetadataItem
|
||||||
{JSON.stringify(image.metadata)}
|
label="Seamless"
|
||||||
</Text>
|
value={seamless}
|
||||||
</Flex>
|
onClick={() => dispatch(setWidth(seamless))}
|
||||||
</>
|
/>
|
||||||
) : (
|
)}
|
||||||
<Center width={'100%'} pt={10}>
|
{width && (
|
||||||
<Text fontSize={'lg'} fontWeight='semibold'>
|
<MetadataItem
|
||||||
No metadata available
|
label="Width"
|
||||||
</Text>
|
value={width}
|
||||||
</Center>
|
onClick={() => dispatch(setWidth(width))}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
{height && (
|
||||||
|
<MetadataItem
|
||||||
|
label="Height"
|
||||||
|
value={height}
|
||||||
|
onClick={() => dispatch(setHeight(height))}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
{init_image_path && (
|
||||||
|
<MetadataItem
|
||||||
|
label="Initial image"
|
||||||
|
value={init_image_path}
|
||||||
|
isLink
|
||||||
|
onClick={() => dispatch(setInitialImagePath(init_image_path))}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
{mask_image_path && (
|
||||||
|
<MetadataItem
|
||||||
|
label="Mask image"
|
||||||
|
value={mask_image_path}
|
||||||
|
isLink
|
||||||
|
onClick={() => dispatch(setMaskPath(mask_image_path))}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
{type === 'img2img' && strength && (
|
||||||
|
<MetadataItem
|
||||||
|
label="Image to image strength"
|
||||||
|
value={strength}
|
||||||
|
onClick={() => dispatch(setImg2imgStrength(strength))}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
{fit && (
|
||||||
|
<MetadataItem
|
||||||
|
label="Image to image fit"
|
||||||
|
value={fit}
|
||||||
|
onClick={() => dispatch(setShouldFitToWidthHeight(fit))}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
{postprocessing &&
|
||||||
|
postprocessing.length > 0 &&
|
||||||
|
postprocessing.map(
|
||||||
|
(postprocess: InvokeAI.PostProcessedImageMetadata) => {
|
||||||
|
if (postprocess.type === 'esrgan') {
|
||||||
|
const { scale, strength } = postprocess;
|
||||||
|
return (
|
||||||
|
<>
|
||||||
|
<MetadataItem
|
||||||
|
label="Upscaling scale"
|
||||||
|
value={scale}
|
||||||
|
onClick={() => dispatch(setUpscalingLevel(scale))}
|
||||||
|
/>
|
||||||
|
<MetadataItem
|
||||||
|
label="Upscaling strength"
|
||||||
|
value={strength}
|
||||||
|
onClick={() => dispatch(setUpscalingStrength(strength))}
|
||||||
|
/>
|
||||||
|
</>
|
||||||
|
);
|
||||||
|
} else if (postprocess.type === 'gfpgan') {
|
||||||
|
const { strength } = postprocess;
|
||||||
|
return (
|
||||||
|
<MetadataItem
|
||||||
|
label="Fix faces strength"
|
||||||
|
value={strength}
|
||||||
|
onClick={() => dispatch(setGfpganStrength(strength))}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
)}
|
)}
|
||||||
</Flex>
|
<Flex gap={2} direction={'column'}>
|
||||||
);
|
<Flex gap={2}>
|
||||||
};
|
<Tooltip label={`Copy JSON`}>
|
||||||
|
<IconButton
|
||||||
|
aria-label="Copy JSON"
|
||||||
|
icon={<FaCopy />}
|
||||||
|
size={'xs'}
|
||||||
|
variant={'ghost'}
|
||||||
|
fontSize={14}
|
||||||
|
onClick={() => navigator.clipboard.writeText(metadataJSON)}
|
||||||
|
/>
|
||||||
|
</Tooltip>
|
||||||
|
<Text fontWeight={'semibold'}>JSON:</Text>
|
||||||
|
</Flex>
|
||||||
|
<Box
|
||||||
|
// maxHeight={200}
|
||||||
|
overflow={'scroll'}
|
||||||
|
flexGrow={3}
|
||||||
|
wordBreak={'break-all'}
|
||||||
|
bgColor={jsonBgColor}
|
||||||
|
padding={2}
|
||||||
|
>
|
||||||
|
<pre>{metadataJSON}</pre>
|
||||||
|
</Box>
|
||||||
|
</Flex>
|
||||||
|
</>
|
||||||
|
) : (
|
||||||
|
<Center width={'100%'} pt={10}>
|
||||||
|
<Text fontSize={'lg'} fontWeight="semibold">
|
||||||
|
No metadata available
|
||||||
|
</Text>
|
||||||
|
</Center>
|
||||||
|
)}
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
}, memoEqualityCheck);
|
||||||
|
|
||||||
export default ImageMetadataViewer;
|
export default ImageMetadataViewer;
|
||||||
|
@ -1,150 +0,0 @@
|
|||||||
import {
|
|
||||||
Box,
|
|
||||||
Flex,
|
|
||||||
Icon,
|
|
||||||
IconButton,
|
|
||||||
Image,
|
|
||||||
useColorModeValue,
|
|
||||||
} from '@chakra-ui/react';
|
|
||||||
import { RootState } from '../../app/store';
|
|
||||||
import { useAppDispatch, useAppSelector } from '../../app/hooks';
|
|
||||||
import { SDImage, setCurrentImage } from './gallerySlice';
|
|
||||||
import { FaCheck, FaCopy, FaSeedling, FaTrash } from 'react-icons/fa';
|
|
||||||
import DeleteImageModalButton from './DeleteImageModalButton';
|
|
||||||
import { memo, SyntheticEvent, useState } from 'react';
|
|
||||||
import { setAllParameters, setSeed } from '../sd/sdSlice';
|
|
||||||
|
|
||||||
interface HoverableImageProps {
|
|
||||||
image: SDImage;
|
|
||||||
isSelected: boolean;
|
|
||||||
}
|
|
||||||
|
|
||||||
const HoverableImage = memo(
|
|
||||||
(props: HoverableImageProps) => {
|
|
||||||
const [isHovered, setIsHovered] = useState<boolean>(false);
|
|
||||||
const dispatch = useAppDispatch();
|
|
||||||
|
|
||||||
const checkColor = useColorModeValue('green.600', 'green.300');
|
|
||||||
const bgColor = useColorModeValue('gray.200', 'gray.700');
|
|
||||||
const bgGradient = useColorModeValue(
|
|
||||||
'radial-gradient(circle, rgba(255,255,255,0.7) 0%, rgba(255,255,255,0.7) 20%, rgba(0,0,0,0) 100%)',
|
|
||||||
'radial-gradient(circle, rgba(0,0,0,0.7) 0%, rgba(0,0,0,0.7) 20%, rgba(0,0,0,0) 100%)'
|
|
||||||
);
|
|
||||||
|
|
||||||
const { image, isSelected } = props;
|
|
||||||
const { url, uuid, metadata } = image;
|
|
||||||
|
|
||||||
const handleMouseOver = () => setIsHovered(true);
|
|
||||||
const handleMouseOut = () => setIsHovered(false);
|
|
||||||
const handleClickSetAllParameters = (e: SyntheticEvent) => {
|
|
||||||
e.stopPropagation();
|
|
||||||
dispatch(setAllParameters(metadata));
|
|
||||||
};
|
|
||||||
const handleClickSetSeed = (e: SyntheticEvent) => {
|
|
||||||
e.stopPropagation();
|
|
||||||
dispatch(setSeed(image.metadata.seed!)); // component not rendered unless this exists
|
|
||||||
};
|
|
||||||
|
|
||||||
return (
|
|
||||||
<Box position={'relative'} key={uuid}>
|
|
||||||
<Image
|
|
||||||
width={120}
|
|
||||||
height={120}
|
|
||||||
objectFit='cover'
|
|
||||||
rounded={'md'}
|
|
||||||
src={url}
|
|
||||||
loading={'lazy'}
|
|
||||||
backgroundColor={bgColor}
|
|
||||||
/>
|
|
||||||
<Flex
|
|
||||||
cursor={'pointer'}
|
|
||||||
position={'absolute'}
|
|
||||||
top={0}
|
|
||||||
left={0}
|
|
||||||
rounded={'md'}
|
|
||||||
width='100%'
|
|
||||||
height='100%'
|
|
||||||
alignItems={'center'}
|
|
||||||
justifyContent={'center'}
|
|
||||||
background={isSelected ? bgGradient : undefined}
|
|
||||||
onClick={() => dispatch(setCurrentImage(image))}
|
|
||||||
onMouseOver={handleMouseOver}
|
|
||||||
onMouseOut={handleMouseOut}
|
|
||||||
>
|
|
||||||
{isSelected && (
|
|
||||||
<Icon
|
|
||||||
fill={checkColor}
|
|
||||||
width={'50%'}
|
|
||||||
height={'50%'}
|
|
||||||
as={FaCheck}
|
|
||||||
/>
|
|
||||||
)}
|
|
||||||
{isHovered && (
|
|
||||||
<Flex
|
|
||||||
direction={'column'}
|
|
||||||
gap={1}
|
|
||||||
position={'absolute'}
|
|
||||||
top={1}
|
|
||||||
right={1}
|
|
||||||
>
|
|
||||||
<DeleteImageModalButton image={image}>
|
|
||||||
<IconButton
|
|
||||||
colorScheme='red'
|
|
||||||
aria-label='Delete image'
|
|
||||||
icon={<FaTrash />}
|
|
||||||
size='xs'
|
|
||||||
fontSize={15}
|
|
||||||
/>
|
|
||||||
</DeleteImageModalButton>
|
|
||||||
<IconButton
|
|
||||||
aria-label='Use all parameters'
|
|
||||||
colorScheme={'blue'}
|
|
||||||
icon={<FaCopy />}
|
|
||||||
size='xs'
|
|
||||||
fontSize={15}
|
|
||||||
onClickCapture={handleClickSetAllParameters}
|
|
||||||
/>
|
|
||||||
{image.metadata.seed && (
|
|
||||||
<IconButton
|
|
||||||
aria-label='Use seed'
|
|
||||||
colorScheme={'blue'}
|
|
||||||
icon={<FaSeedling />}
|
|
||||||
size='xs'
|
|
||||||
fontSize={16}
|
|
||||||
onClickCapture={handleClickSetSeed}
|
|
||||||
/>
|
|
||||||
)}
|
|
||||||
</Flex>
|
|
||||||
)}
|
|
||||||
</Flex>
|
|
||||||
</Box>
|
|
||||||
);
|
|
||||||
},
|
|
||||||
(prev, next) =>
|
|
||||||
prev.image.uuid === next.image.uuid &&
|
|
||||||
prev.isSelected === next.isSelected
|
|
||||||
);
|
|
||||||
|
|
||||||
const ImageRoll = () => {
|
|
||||||
const { images, currentImageUuid } = useAppSelector(
|
|
||||||
(state: RootState) => state.gallery
|
|
||||||
);
|
|
||||||
|
|
||||||
return (
|
|
||||||
<Flex gap={2} wrap='wrap' pb={2}>
|
|
||||||
{[...images].reverse().map((image) => {
|
|
||||||
const { uuid } = image;
|
|
||||||
const isSelected = currentImageUuid === uuid;
|
|
||||||
return (
|
|
||||||
<HoverableImage
|
|
||||||
key={uuid}
|
|
||||||
image={image}
|
|
||||||
isSelected={isSelected}
|
|
||||||
/>
|
|
||||||
);
|
|
||||||
})}
|
|
||||||
</Flex>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
export default ImageRoll;
|
|
@ -1,40 +1,13 @@
|
|||||||
import { createSlice } from '@reduxjs/toolkit';
|
import { createSlice } from '@reduxjs/toolkit';
|
||||||
import type { PayloadAction } from '@reduxjs/toolkit';
|
import type { PayloadAction } from '@reduxjs/toolkit';
|
||||||
import { v4 as uuidv4 } from 'uuid';
|
import { clamp } from 'lodash';
|
||||||
import { UpscalingLevel } from '../sd/sdSlice';
|
import * as InvokeAI from '../../app/invokeai';
|
||||||
import { backendToFrontendParameters } from '../../app/parameterTranslation';
|
|
||||||
|
|
||||||
// TODO: Revise pending metadata RFC: https://github.com/lstein/stable-diffusion/issues/266
|
|
||||||
export interface SDMetadata {
|
|
||||||
prompt?: string;
|
|
||||||
steps?: number;
|
|
||||||
cfgScale?: number;
|
|
||||||
height?: number;
|
|
||||||
width?: number;
|
|
||||||
sampler?: string;
|
|
||||||
seed?: number;
|
|
||||||
img2imgStrength?: number;
|
|
||||||
gfpganStrength?: number;
|
|
||||||
upscalingLevel?: UpscalingLevel;
|
|
||||||
upscalingStrength?: number;
|
|
||||||
initialImagePath?: string;
|
|
||||||
maskPath?: string;
|
|
||||||
seamless?: boolean;
|
|
||||||
shouldFitToWidthHeight?: boolean;
|
|
||||||
}
|
|
||||||
|
|
||||||
export interface SDImage {
|
|
||||||
// TODO: I have installed @types/uuid but cannot figure out how to use them here.
|
|
||||||
uuid: string;
|
|
||||||
url: string;
|
|
||||||
metadata: SDMetadata;
|
|
||||||
}
|
|
||||||
|
|
||||||
export interface GalleryState {
|
export interface GalleryState {
|
||||||
|
currentImage?: InvokeAI.Image;
|
||||||
currentImageUuid: string;
|
currentImageUuid: string;
|
||||||
images: Array<SDImage>;
|
images: Array<InvokeAI.Image>;
|
||||||
intermediateImage?: SDImage;
|
intermediateImage?: InvokeAI.Image;
|
||||||
currentImage?: SDImage;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const initialState: GalleryState = {
|
const initialState: GalleryState = {
|
||||||
@ -46,99 +19,84 @@ export const gallerySlice = createSlice({
|
|||||||
name: 'gallery',
|
name: 'gallery',
|
||||||
initialState,
|
initialState,
|
||||||
reducers: {
|
reducers: {
|
||||||
setCurrentImage: (state, action: PayloadAction<SDImage>) => {
|
setCurrentImage: (state, action: PayloadAction<InvokeAI.Image>) => {
|
||||||
state.currentImage = action.payload;
|
state.currentImage = action.payload;
|
||||||
state.currentImageUuid = action.payload.uuid;
|
state.currentImageUuid = action.payload.uuid;
|
||||||
},
|
},
|
||||||
removeImage: (state, action: PayloadAction<SDImage>) => {
|
removeImage: (state, action: PayloadAction<string>) => {
|
||||||
const { uuid } = action.payload;
|
const uuid = action.payload;
|
||||||
|
|
||||||
const newImages = state.images.filter((image) => image.uuid !== uuid);
|
const newImages = state.images.filter((image) => image.uuid !== uuid);
|
||||||
|
|
||||||
const imageToDeleteIndex = state.images.findIndex(
|
if (uuid === state.currentImageUuid) {
|
||||||
(image) => image.uuid === uuid
|
/**
|
||||||
);
|
* We are deleting the currently selected image.
|
||||||
|
*
|
||||||
|
* We want the new currentl selected image to be under the cursor in the
|
||||||
|
* gallery, so we need to do some fanagling. The currently selected image
|
||||||
|
* is set by its UUID, not its index in the image list.
|
||||||
|
*
|
||||||
|
* Get the currently selected image's index.
|
||||||
|
*/
|
||||||
|
const imageToDeleteIndex = state.images.findIndex(
|
||||||
|
(image) => image.uuid === uuid
|
||||||
|
);
|
||||||
|
|
||||||
const newCurrentImageIndex = Math.min(
|
/**
|
||||||
Math.max(imageToDeleteIndex, 0),
|
* New current image needs to be in the same spot, but because the gallery
|
||||||
newImages.length - 1
|
* is sorted in reverse order, the new current image's index will actuall be
|
||||||
);
|
* one less than the deleted image's index.
|
||||||
|
*
|
||||||
|
* Clamp the new index to ensure it is valid..
|
||||||
|
*/
|
||||||
|
const newCurrentImageIndex = clamp(
|
||||||
|
imageToDeleteIndex - 1,
|
||||||
|
0,
|
||||||
|
newImages.length - 1
|
||||||
|
);
|
||||||
|
|
||||||
|
state.currentImage = newImages.length
|
||||||
|
? newImages[newCurrentImageIndex]
|
||||||
|
: undefined;
|
||||||
|
|
||||||
|
state.currentImageUuid = newImages.length
|
||||||
|
? newImages[newCurrentImageIndex].uuid
|
||||||
|
: '';
|
||||||
|
}
|
||||||
|
|
||||||
state.images = newImages;
|
state.images = newImages;
|
||||||
|
|
||||||
state.currentImage = newImages.length
|
|
||||||
? newImages[newCurrentImageIndex]
|
|
||||||
: undefined;
|
|
||||||
|
|
||||||
state.currentImageUuid = newImages.length
|
|
||||||
? newImages[newCurrentImageIndex].uuid
|
|
||||||
: '';
|
|
||||||
},
|
},
|
||||||
addImage: (state, action: PayloadAction<SDImage>) => {
|
addImage: (state, action: PayloadAction<InvokeAI.Image>) => {
|
||||||
state.images.push(action.payload);
|
state.images.push(action.payload);
|
||||||
state.currentImageUuid = action.payload.uuid;
|
state.currentImageUuid = action.payload.uuid;
|
||||||
state.intermediateImage = undefined;
|
state.intermediateImage = undefined;
|
||||||
state.currentImage = action.payload;
|
state.currentImage = action.payload;
|
||||||
},
|
},
|
||||||
setIntermediateImage: (state, action: PayloadAction<SDImage>) => {
|
setIntermediateImage: (state, action: PayloadAction<InvokeAI.Image>) => {
|
||||||
state.intermediateImage = action.payload;
|
state.intermediateImage = action.payload;
|
||||||
},
|
},
|
||||||
clearIntermediateImage: (state) => {
|
clearIntermediateImage: (state) => {
|
||||||
state.intermediateImage = undefined;
|
state.intermediateImage = undefined;
|
||||||
},
|
},
|
||||||
setGalleryImages: (
|
setGalleryImages: (state, action: PayloadAction<Array<InvokeAI.Image>>) => {
|
||||||
state,
|
const newImages = action.payload;
|
||||||
action: PayloadAction<
|
if (newImages.length) {
|
||||||
Array<{
|
const newCurrentImage = newImages[newImages.length - 1];
|
||||||
path: string;
|
|
||||||
metadata: { [key: string]: string | number | boolean };
|
|
||||||
}>
|
|
||||||
>
|
|
||||||
) => {
|
|
||||||
// TODO: Revise pending metadata RFC: https://github.com/lstein/stable-diffusion/issues/266
|
|
||||||
const images = action.payload;
|
|
||||||
|
|
||||||
if (images.length === 0) {
|
|
||||||
// there are no images on disk, clear the gallery
|
|
||||||
state.images = [];
|
|
||||||
state.currentImageUuid = '';
|
|
||||||
state.currentImage = undefined;
|
|
||||||
} else {
|
|
||||||
// Filter image urls that are already in the rehydrated state
|
|
||||||
const filteredImages = action.payload.filter(
|
|
||||||
(image) => !state.images.find((i) => i.url === image.path)
|
|
||||||
);
|
|
||||||
|
|
||||||
const preparedImages = filteredImages.map((image): SDImage => {
|
|
||||||
return {
|
|
||||||
uuid: uuidv4(),
|
|
||||||
url: image.path,
|
|
||||||
metadata: backendToFrontendParameters(image.metadata),
|
|
||||||
};
|
|
||||||
});
|
|
||||||
|
|
||||||
const newImages = [...state.images].concat(preparedImages);
|
|
||||||
|
|
||||||
// if previous currentimage no longer exists, set a new one
|
|
||||||
if (!newImages.find((image) => image.uuid === state.currentImageUuid)) {
|
|
||||||
const newCurrentImage = newImages[newImages.length - 1];
|
|
||||||
state.currentImage = newCurrentImage;
|
|
||||||
state.currentImageUuid = newCurrentImage.uuid;
|
|
||||||
}
|
|
||||||
|
|
||||||
state.images = newImages;
|
state.images = newImages;
|
||||||
|
state.currentImage = newCurrentImage;
|
||||||
|
state.currentImageUuid = newCurrentImage.uuid;
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
export const {
|
export const {
|
||||||
setCurrentImage,
|
|
||||||
removeImage,
|
|
||||||
addImage,
|
addImage,
|
||||||
|
clearIntermediateImage,
|
||||||
|
removeImage,
|
||||||
|
setCurrentImage,
|
||||||
setGalleryImages,
|
setGalleryImages,
|
||||||
setIntermediateImage,
|
setIntermediateImage,
|
||||||
clearIntermediateImage,
|
|
||||||
} = gallerySlice.actions;
|
} = gallerySlice.actions;
|
||||||
|
|
||||||
export default gallerySlice.reducer;
|
export default gallerySlice.reducer;
|
||||||
|
@ -1,35 +0,0 @@
|
|||||||
import { Progress } from '@chakra-ui/react';
|
|
||||||
import { createSelector } from '@reduxjs/toolkit';
|
|
||||||
import { isEqual } from 'lodash';
|
|
||||||
import { useAppSelector } from '../../app/hooks';
|
|
||||||
import { RootState } from '../../app/store';
|
|
||||||
import { SDState } from '../sd/sdSlice';
|
|
||||||
|
|
||||||
const sdSelector = createSelector(
|
|
||||||
(state: RootState) => state.sd,
|
|
||||||
(sd: SDState) => {
|
|
||||||
return {
|
|
||||||
realSteps: sd.realSteps,
|
|
||||||
};
|
|
||||||
},
|
|
||||||
{
|
|
||||||
memoizeOptions: {
|
|
||||||
resultEqualityCheck: isEqual,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
);
|
|
||||||
|
|
||||||
const ProgressBar = () => {
|
|
||||||
const { realSteps } = useAppSelector(sdSelector);
|
|
||||||
const { currentStep } = useAppSelector((state: RootState) => state.system);
|
|
||||||
const progress = Math.round((currentStep * 100) / realSteps);
|
|
||||||
return (
|
|
||||||
<Progress
|
|
||||||
height='10px'
|
|
||||||
value={progress}
|
|
||||||
isIndeterminate={progress < 0 || currentStep === realSteps}
|
|
||||||
/>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
export default ProgressBar;
|
|
@ -1,93 +0,0 @@
|
|||||||
import {
|
|
||||||
Flex,
|
|
||||||
Heading,
|
|
||||||
IconButton,
|
|
||||||
Link,
|
|
||||||
Spacer,
|
|
||||||
Text,
|
|
||||||
useColorMode,
|
|
||||||
} from '@chakra-ui/react';
|
|
||||||
import { createSelector } from '@reduxjs/toolkit';
|
|
||||||
import { isEqual } from 'lodash';
|
|
||||||
|
|
||||||
import { FaSun, FaMoon, FaGithub } from 'react-icons/fa';
|
|
||||||
import { MdHelp, MdSettings } from 'react-icons/md';
|
|
||||||
import { useAppSelector } from '../../app/hooks';
|
|
||||||
import { RootState } from '../../app/store';
|
|
||||||
import SettingsModal from '../system/SettingsModal';
|
|
||||||
import { SystemState } from '../system/systemSlice';
|
|
||||||
|
|
||||||
const systemSelector = createSelector(
|
|
||||||
(state: RootState) => state.system,
|
|
||||||
(system: SystemState) => {
|
|
||||||
return { isConnected: system.isConnected };
|
|
||||||
},
|
|
||||||
{
|
|
||||||
memoizeOptions: { resultEqualityCheck: isEqual },
|
|
||||||
}
|
|
||||||
);
|
|
||||||
|
|
||||||
const SiteHeader = () => {
|
|
||||||
const { colorMode, toggleColorMode } = useColorMode();
|
|
||||||
const { isConnected } = useAppSelector(systemSelector);
|
|
||||||
|
|
||||||
return (
|
|
||||||
<Flex minWidth='max-content' alignItems='center' gap='1' pl={2} pr={1}>
|
|
||||||
<Heading size={'lg'}>Stable Diffusion Dream Server</Heading>
|
|
||||||
|
|
||||||
<Spacer />
|
|
||||||
|
|
||||||
<Text textColor={isConnected ? 'green.500' : 'red.500'}>
|
|
||||||
{isConnected ? `Connected to server` : 'No connection to server'}
|
|
||||||
</Text>
|
|
||||||
|
|
||||||
<SettingsModal>
|
|
||||||
<IconButton
|
|
||||||
aria-label='Settings'
|
|
||||||
variant='link'
|
|
||||||
fontSize={24}
|
|
||||||
size={'sm'}
|
|
||||||
icon={<MdSettings />}
|
|
||||||
/>
|
|
||||||
</SettingsModal>
|
|
||||||
|
|
||||||
<IconButton
|
|
||||||
aria-label='Link to Github Issues'
|
|
||||||
variant='link'
|
|
||||||
fontSize={23}
|
|
||||||
size={'sm'}
|
|
||||||
icon={
|
|
||||||
<Link
|
|
||||||
isExternal
|
|
||||||
href='http://github.com/lstein/stable-diffusion/issues'
|
|
||||||
>
|
|
||||||
<MdHelp />
|
|
||||||
</Link>
|
|
||||||
}
|
|
||||||
/>
|
|
||||||
|
|
||||||
<IconButton
|
|
||||||
aria-label='Link to Github Repo'
|
|
||||||
variant='link'
|
|
||||||
fontSize={20}
|
|
||||||
size={'sm'}
|
|
||||||
icon={
|
|
||||||
<Link isExternal href='http://github.com/lstein/stable-diffusion'>
|
|
||||||
<FaGithub />
|
|
||||||
</Link>
|
|
||||||
}
|
|
||||||
/>
|
|
||||||
|
|
||||||
<IconButton
|
|
||||||
aria-label='Toggle Dark Mode'
|
|
||||||
onClick={toggleColorMode}
|
|
||||||
variant='link'
|
|
||||||
size={'sm'}
|
|
||||||
fontSize={colorMode == 'light' ? 18 : 20}
|
|
||||||
icon={colorMode == 'light' ? <FaMoon /> : <FaSun />}
|
|
||||||
/>
|
|
||||||
</Flex>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
export default SiteHeader;
|
|
87
frontend/src/features/options/ESRGANOptions.tsx
Normal file
87
frontend/src/features/options/ESRGANOptions.tsx
Normal file
@ -0,0 +1,87 @@
|
|||||||
|
import { Flex } from '@chakra-ui/react';
|
||||||
|
|
||||||
|
import { RootState } from '../../app/store';
|
||||||
|
import { useAppDispatch, useAppSelector } from '../../app/store';
|
||||||
|
|
||||||
|
import {
|
||||||
|
setUpscalingLevel,
|
||||||
|
setUpscalingStrength,
|
||||||
|
UpscalingLevel,
|
||||||
|
OptionsState,
|
||||||
|
} from '../options/optionsSlice';
|
||||||
|
|
||||||
|
|
||||||
|
import { UPSCALING_LEVELS } from '../../app/constants';
|
||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { isEqual } from 'lodash';
|
||||||
|
import { SystemState } from '../system/systemSlice';
|
||||||
|
import { ChangeEvent } from 'react';
|
||||||
|
import SDNumberInput from '../../common/components/SDNumberInput';
|
||||||
|
import SDSelect from '../../common/components/SDSelect';
|
||||||
|
|
||||||
|
const optionsSelector = createSelector(
|
||||||
|
(state: RootState) => state.options,
|
||||||
|
(options: OptionsState) => {
|
||||||
|
return {
|
||||||
|
upscalingLevel: options.upscalingLevel,
|
||||||
|
upscalingStrength: options.upscalingStrength,
|
||||||
|
};
|
||||||
|
},
|
||||||
|
{
|
||||||
|
memoizeOptions: {
|
||||||
|
resultEqualityCheck: isEqual,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
const systemSelector = createSelector(
|
||||||
|
(state: RootState) => state.system,
|
||||||
|
(system: SystemState) => {
|
||||||
|
return {
|
||||||
|
isESRGANAvailable: system.isESRGANAvailable,
|
||||||
|
};
|
||||||
|
},
|
||||||
|
{
|
||||||
|
memoizeOptions: {
|
||||||
|
resultEqualityCheck: isEqual,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Displays upscaling/ESRGAN options (level and strength).
|
||||||
|
*/
|
||||||
|
const ESRGANOptions = () => {
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const { upscalingLevel, upscalingStrength } = useAppSelector(optionsSelector);
|
||||||
|
const { isESRGANAvailable } = useAppSelector(systemSelector);
|
||||||
|
|
||||||
|
const handleChangeLevel = (e: ChangeEvent<HTMLSelectElement>) =>
|
||||||
|
dispatch(setUpscalingLevel(Number(e.target.value) as UpscalingLevel));
|
||||||
|
|
||||||
|
const handleChangeStrength = (v: string | number) =>
|
||||||
|
dispatch(setUpscalingStrength(Number(v)));
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Flex direction={'column'} gap={2}>
|
||||||
|
<SDSelect
|
||||||
|
isDisabled={!isESRGANAvailable}
|
||||||
|
label="Scale"
|
||||||
|
value={upscalingLevel}
|
||||||
|
onChange={handleChangeLevel}
|
||||||
|
validValues={UPSCALING_LEVELS}
|
||||||
|
/>
|
||||||
|
<SDNumberInput
|
||||||
|
isDisabled={!isESRGANAvailable}
|
||||||
|
label="Strength"
|
||||||
|
step={0.05}
|
||||||
|
min={0}
|
||||||
|
max={1}
|
||||||
|
onChange={handleChangeStrength}
|
||||||
|
value={upscalingStrength}
|
||||||
|
/>
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default ESRGANOptions;
|
68
frontend/src/features/options/GFPGANOptions.tsx
Normal file
68
frontend/src/features/options/GFPGANOptions.tsx
Normal file
@ -0,0 +1,68 @@
|
|||||||
|
import { Flex } from '@chakra-ui/react';
|
||||||
|
|
||||||
|
import { RootState } from '../../app/store';
|
||||||
|
import { useAppDispatch, useAppSelector } from '../../app/store';
|
||||||
|
|
||||||
|
import { OptionsState, setGfpganStrength } from '../options/optionsSlice';
|
||||||
|
|
||||||
|
|
||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { isEqual } from 'lodash';
|
||||||
|
import { SystemState } from '../system/systemSlice';
|
||||||
|
import SDNumberInput from '../../common/components/SDNumberInput';
|
||||||
|
|
||||||
|
const optionsSelector = createSelector(
|
||||||
|
(state: RootState) => state.options,
|
||||||
|
(options: OptionsState) => {
|
||||||
|
return {
|
||||||
|
gfpganStrength: options.gfpganStrength,
|
||||||
|
};
|
||||||
|
},
|
||||||
|
{
|
||||||
|
memoizeOptions: {
|
||||||
|
resultEqualityCheck: isEqual,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
const systemSelector = createSelector(
|
||||||
|
(state: RootState) => state.system,
|
||||||
|
(system: SystemState) => {
|
||||||
|
return {
|
||||||
|
isGFPGANAvailable: system.isGFPGANAvailable,
|
||||||
|
};
|
||||||
|
},
|
||||||
|
{
|
||||||
|
memoizeOptions: {
|
||||||
|
resultEqualityCheck: isEqual,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Displays face-fixing/GFPGAN options (strength).
|
||||||
|
*/
|
||||||
|
const GFPGANOptions = () => {
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const { gfpganStrength } = useAppSelector(optionsSelector);
|
||||||
|
const { isGFPGANAvailable } = useAppSelector(systemSelector);
|
||||||
|
|
||||||
|
const handleChangeStrength = (v: string | number) =>
|
||||||
|
dispatch(setGfpganStrength(Number(v)));
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Flex direction={'column'} gap={2}>
|
||||||
|
<SDNumberInput
|
||||||
|
isDisabled={!isGFPGANAvailable}
|
||||||
|
label="Strength"
|
||||||
|
step={0.05}
|
||||||
|
min={0}
|
||||||
|
max={1}
|
||||||
|
onChange={handleChangeStrength}
|
||||||
|
value={gfpganStrength}
|
||||||
|
/>
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default GFPGANOptions;
|
59
frontend/src/features/options/ImageToImageOptions.tsx
Normal file
59
frontend/src/features/options/ImageToImageOptions.tsx
Normal file
@ -0,0 +1,59 @@
|
|||||||
|
import { Flex } from '@chakra-ui/react';
|
||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { ChangeEvent } from 'react';
|
||||||
|
import { useAppDispatch, useAppSelector } from '../../app/store';
|
||||||
|
import { RootState } from '../../app/store';
|
||||||
|
import SDNumberInput from '../../common/components/SDNumberInput';
|
||||||
|
import SDSwitch from '../../common/components/SDSwitch';
|
||||||
|
import InitAndMaskImage from './InitAndMaskImage';
|
||||||
|
import {
|
||||||
|
OptionsState,
|
||||||
|
setImg2imgStrength,
|
||||||
|
setShouldFitToWidthHeight,
|
||||||
|
} from './optionsSlice';
|
||||||
|
|
||||||
|
const optionsSelector = createSelector(
|
||||||
|
(state: RootState) => state.options,
|
||||||
|
(options: OptionsState) => {
|
||||||
|
return {
|
||||||
|
img2imgStrength: options.img2imgStrength,
|
||||||
|
shouldFitToWidthHeight: options.shouldFitToWidthHeight,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Options for img2img generation (strength, fit, init/mask upload).
|
||||||
|
*/
|
||||||
|
const ImageToImageOptions = () => {
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const { img2imgStrength, shouldFitToWidthHeight } =
|
||||||
|
useAppSelector(optionsSelector);
|
||||||
|
|
||||||
|
const handleChangeStrength = (v: string | number) =>
|
||||||
|
dispatch(setImg2imgStrength(Number(v)));
|
||||||
|
|
||||||
|
const handleChangeFit = (e: ChangeEvent<HTMLInputElement>) =>
|
||||||
|
dispatch(setShouldFitToWidthHeight(e.target.checked));
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Flex direction={'column'} gap={2}>
|
||||||
|
<SDNumberInput
|
||||||
|
label="Strength"
|
||||||
|
step={0.01}
|
||||||
|
min={0}
|
||||||
|
max={1}
|
||||||
|
onChange={handleChangeStrength}
|
||||||
|
value={img2imgStrength}
|
||||||
|
/>
|
||||||
|
<SDSwitch
|
||||||
|
label="Fit initial image to output size"
|
||||||
|
isChecked={shouldFitToWidthHeight}
|
||||||
|
onChange={handleChangeFit}
|
||||||
|
/>
|
||||||
|
<InitAndMaskImage />
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default ImageToImageOptions;
|
64
frontend/src/features/options/ImageUploader.tsx
Normal file
64
frontend/src/features/options/ImageUploader.tsx
Normal file
@ -0,0 +1,64 @@
|
|||||||
|
import { Box } from '@chakra-ui/react';
|
||||||
|
import { cloneElement, ReactElement, SyntheticEvent, useCallback } from 'react';
|
||||||
|
import { FileRejection, useDropzone } from 'react-dropzone';
|
||||||
|
|
||||||
|
type ImageUploaderProps = {
|
||||||
|
/**
|
||||||
|
* Component which, on click, should open the upload interface.
|
||||||
|
*/
|
||||||
|
children: ReactElement;
|
||||||
|
/**
|
||||||
|
* Callback to handle uploading the selected file.
|
||||||
|
*/
|
||||||
|
fileAcceptedCallback: (file: File) => void;
|
||||||
|
/**
|
||||||
|
* Callback to handle a file being rejected.
|
||||||
|
*/
|
||||||
|
fileRejectionCallback: (rejection: FileRejection) => void;
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* File upload using react-dropzone.
|
||||||
|
* Needs a child to be the button to activate the upload interface.
|
||||||
|
*/
|
||||||
|
const ImageUploader = ({
|
||||||
|
children,
|
||||||
|
fileAcceptedCallback,
|
||||||
|
fileRejectionCallback,
|
||||||
|
}: ImageUploaderProps) => {
|
||||||
|
const onDrop = useCallback(
|
||||||
|
(acceptedFiles: Array<File>, fileRejections: Array<FileRejection>) => {
|
||||||
|
fileRejections.forEach((rejection: FileRejection) => {
|
||||||
|
fileRejectionCallback(rejection);
|
||||||
|
});
|
||||||
|
|
||||||
|
acceptedFiles.forEach((file: File) => {
|
||||||
|
fileAcceptedCallback(file);
|
||||||
|
});
|
||||||
|
},
|
||||||
|
[fileAcceptedCallback, fileRejectionCallback]
|
||||||
|
);
|
||||||
|
|
||||||
|
const { getRootProps, getInputProps, open } = useDropzone({
|
||||||
|
onDrop,
|
||||||
|
accept: {
|
||||||
|
'image/jpeg': ['.jpg', '.jpeg', '.png'],
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
const handleClickUploadIcon = (e: SyntheticEvent) => {
|
||||||
|
e.stopPropagation();
|
||||||
|
open();
|
||||||
|
};
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Box {...getRootProps()} flexGrow={3}>
|
||||||
|
<input {...getInputProps({ multiple: false })} />
|
||||||
|
{cloneElement(children, {
|
||||||
|
onClick: handleClickUploadIcon,
|
||||||
|
})}
|
||||||
|
</Box>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default ImageUploader;
|
57
frontend/src/features/options/InitAndMaskImage.tsx
Normal file
57
frontend/src/features/options/InitAndMaskImage.tsx
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
import { Flex, Image } from '@chakra-ui/react';
|
||||||
|
import { useState } from 'react';
|
||||||
|
import { useAppSelector } from '../../app/store';
|
||||||
|
import { RootState } from '../../app/store';
|
||||||
|
import { OptionsState } from '../../features/options/optionsSlice';
|
||||||
|
import './InitAndMaskImage.css';
|
||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { isEqual } from 'lodash';
|
||||||
|
import InitAndMaskUploadButtons from './InitAndMaskUploadButtons';
|
||||||
|
|
||||||
|
const optionsSelector = createSelector(
|
||||||
|
(state: RootState) => state.options,
|
||||||
|
(options: OptionsState) => {
|
||||||
|
return {
|
||||||
|
initialImagePath: options.initialImagePath,
|
||||||
|
maskPath: options.maskPath,
|
||||||
|
};
|
||||||
|
},
|
||||||
|
{ memoizeOptions: { resultEqualityCheck: isEqual } }
|
||||||
|
);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Displays init and mask images and buttons to upload/delete them.
|
||||||
|
*/
|
||||||
|
const InitAndMaskImage = () => {
|
||||||
|
const { initialImagePath, maskPath } = useAppSelector(optionsSelector);
|
||||||
|
const [shouldShowMask, setShouldShowMask] = useState<boolean>(false);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Flex direction={'column'} alignItems={'center'} gap={2}>
|
||||||
|
<InitAndMaskUploadButtons setShouldShowMask={setShouldShowMask} />
|
||||||
|
{initialImagePath && (
|
||||||
|
<Flex position={'relative'} width={'100%'}>
|
||||||
|
<Image
|
||||||
|
fit={'contain'}
|
||||||
|
src={initialImagePath}
|
||||||
|
rounded={'md'}
|
||||||
|
className={'checkerboard'}
|
||||||
|
/>
|
||||||
|
{shouldShowMask && maskPath && (
|
||||||
|
<Image
|
||||||
|
position={'absolute'}
|
||||||
|
top={0}
|
||||||
|
left={0}
|
||||||
|
fit={'contain'}
|
||||||
|
src={maskPath}
|
||||||
|
rounded={'md'}
|
||||||
|
zIndex={1}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
</Flex>
|
||||||
|
)}
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default InitAndMaskImage;
|
151
frontend/src/features/options/InitAndMaskUploadButtons.tsx
Normal file
151
frontend/src/features/options/InitAndMaskUploadButtons.tsx
Normal file
@ -0,0 +1,151 @@
|
|||||||
|
import { Button, Flex, IconButton, useToast } from '@chakra-ui/react';
|
||||||
|
import { SyntheticEvent, useCallback } from 'react';
|
||||||
|
import { FaTrash, FaUpload } from 'react-icons/fa';
|
||||||
|
import { useAppDispatch, useAppSelector } from '../../app/store';
|
||||||
|
import { RootState } from '../../app/store';
|
||||||
|
import {
|
||||||
|
OptionsState,
|
||||||
|
setInitialImagePath,
|
||||||
|
setMaskPath,
|
||||||
|
} from '../../features/options/optionsSlice';
|
||||||
|
import {
|
||||||
|
uploadInitialImage,
|
||||||
|
uploadMaskImage,
|
||||||
|
} from '../../app/socketio/actions';
|
||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { isEqual } from 'lodash';
|
||||||
|
import ImageUploader from './ImageUploader';
|
||||||
|
import { FileRejection } from 'react-dropzone';
|
||||||
|
|
||||||
|
const optionsSelector = createSelector(
|
||||||
|
(state: RootState) => state.options,
|
||||||
|
(options: OptionsState) => {
|
||||||
|
return {
|
||||||
|
initialImagePath: options.initialImagePath,
|
||||||
|
maskPath: options.maskPath,
|
||||||
|
};
|
||||||
|
},
|
||||||
|
{ memoizeOptions: { resultEqualityCheck: isEqual } }
|
||||||
|
);
|
||||||
|
|
||||||
|
type InitAndMaskUploadButtonsProps = {
|
||||||
|
setShouldShowMask: (b: boolean) => void;
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Init and mask image upload buttons.
|
||||||
|
*/
|
||||||
|
const InitAndMaskUploadButtons = ({
|
||||||
|
setShouldShowMask,
|
||||||
|
}: InitAndMaskUploadButtonsProps) => {
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const { initialImagePath, maskPath } = useAppSelector(optionsSelector);
|
||||||
|
|
||||||
|
// Use a toast to alert user when a file upload is rejected
|
||||||
|
const toast = useToast();
|
||||||
|
|
||||||
|
// Clear the init and mask images
|
||||||
|
const handleClickResetInitialImage = (e: SyntheticEvent) => {
|
||||||
|
e.stopPropagation();
|
||||||
|
dispatch(setInitialImagePath(''));
|
||||||
|
};
|
||||||
|
|
||||||
|
// Clear the init and mask images
|
||||||
|
const handleClickResetMask = (e: SyntheticEvent) => {
|
||||||
|
e.stopPropagation();
|
||||||
|
dispatch(setMaskPath(''));
|
||||||
|
};
|
||||||
|
|
||||||
|
// Handle hover to view initial image and mask image
|
||||||
|
const handleMouseOverInitialImageUploadButton = () =>
|
||||||
|
setShouldShowMask(false);
|
||||||
|
const handleMouseOutInitialImageUploadButton = () => setShouldShowMask(true);
|
||||||
|
|
||||||
|
const handleMouseOverMaskUploadButton = () => setShouldShowMask(true);
|
||||||
|
const handleMouseOutMaskUploadButton = () => setShouldShowMask(true);
|
||||||
|
|
||||||
|
// Callbacks to for handling file upload attempts
|
||||||
|
const initImageFileAcceptedCallback = useCallback(
|
||||||
|
(file: File) => dispatch(uploadInitialImage(file)),
|
||||||
|
[dispatch]
|
||||||
|
);
|
||||||
|
|
||||||
|
const maskImageFileAcceptedCallback = useCallback(
|
||||||
|
(file: File) => dispatch(uploadMaskImage(file)),
|
||||||
|
[dispatch]
|
||||||
|
);
|
||||||
|
|
||||||
|
const fileRejectionCallback = useCallback(
|
||||||
|
(rejection: FileRejection) => {
|
||||||
|
const msg = rejection.errors.reduce(
|
||||||
|
(acc: string, cur: { message: string }) => acc + '\n' + cur.message,
|
||||||
|
''
|
||||||
|
);
|
||||||
|
|
||||||
|
toast({
|
||||||
|
title: 'Upload failed',
|
||||||
|
description: msg,
|
||||||
|
status: 'error',
|
||||||
|
isClosable: true,
|
||||||
|
});
|
||||||
|
},
|
||||||
|
[toast]
|
||||||
|
);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Flex gap={2} justifyContent={'space-between'} width={'100%'}>
|
||||||
|
<ImageUploader
|
||||||
|
fileAcceptedCallback={initImageFileAcceptedCallback}
|
||||||
|
fileRejectionCallback={fileRejectionCallback}
|
||||||
|
>
|
||||||
|
<Button
|
||||||
|
size={'sm'}
|
||||||
|
fontSize={'md'}
|
||||||
|
fontWeight={'normal'}
|
||||||
|
onMouseOver={handleMouseOverInitialImageUploadButton}
|
||||||
|
onMouseOut={handleMouseOutInitialImageUploadButton}
|
||||||
|
leftIcon={<FaUpload />}
|
||||||
|
width={'100%'}
|
||||||
|
>
|
||||||
|
Image
|
||||||
|
</Button>
|
||||||
|
</ImageUploader>
|
||||||
|
|
||||||
|
<IconButton
|
||||||
|
isDisabled={!initialImagePath}
|
||||||
|
size={'sm'}
|
||||||
|
aria-label={'Reset mask'}
|
||||||
|
onClick={handleClickResetInitialImage}
|
||||||
|
icon={<FaTrash />}
|
||||||
|
/>
|
||||||
|
|
||||||
|
<ImageUploader
|
||||||
|
fileAcceptedCallback={maskImageFileAcceptedCallback}
|
||||||
|
fileRejectionCallback={fileRejectionCallback}
|
||||||
|
>
|
||||||
|
<Button
|
||||||
|
isDisabled={!initialImagePath}
|
||||||
|
size={'sm'}
|
||||||
|
fontSize={'md'}
|
||||||
|
fontWeight={'normal'}
|
||||||
|
onMouseOver={handleMouseOverMaskUploadButton}
|
||||||
|
onMouseOut={handleMouseOutMaskUploadButton}
|
||||||
|
leftIcon={<FaUpload />}
|
||||||
|
width={'100%'}
|
||||||
|
>
|
||||||
|
Mask
|
||||||
|
</Button>
|
||||||
|
</ImageUploader>
|
||||||
|
|
||||||
|
<IconButton
|
||||||
|
isDisabled={!maskPath}
|
||||||
|
size={'sm'}
|
||||||
|
aria-label={'Reset mask'}
|
||||||
|
onClick={handleClickResetMask}
|
||||||
|
icon={<FaTrash />}
|
||||||
|
/>
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default InitAndMaskUploadButtons;
|
217
frontend/src/features/options/OptionsAccordion.tsx
Normal file
217
frontend/src/features/options/OptionsAccordion.tsx
Normal file
@ -0,0 +1,217 @@
|
|||||||
|
import {
|
||||||
|
Flex,
|
||||||
|
Box,
|
||||||
|
Text,
|
||||||
|
Accordion,
|
||||||
|
AccordionItem,
|
||||||
|
AccordionButton,
|
||||||
|
AccordionIcon,
|
||||||
|
AccordionPanel,
|
||||||
|
Switch,
|
||||||
|
ExpandedIndex,
|
||||||
|
} from '@chakra-ui/react';
|
||||||
|
|
||||||
|
import { RootState } from '../../app/store';
|
||||||
|
import { useAppDispatch, useAppSelector } from '../../app/store';
|
||||||
|
|
||||||
|
import {
|
||||||
|
setShouldRunGFPGAN,
|
||||||
|
setShouldRunESRGAN,
|
||||||
|
OptionsState,
|
||||||
|
setShouldUseInitImage,
|
||||||
|
} from '../options/optionsSlice';
|
||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { isEqual } from 'lodash';
|
||||||
|
import { setOpenAccordions, SystemState } from '../system/systemSlice';
|
||||||
|
import SeedVariationOptions from './SeedVariationOptions';
|
||||||
|
import SamplerOptions from './SamplerOptions';
|
||||||
|
import ESRGANOptions from './ESRGANOptions';
|
||||||
|
import GFPGANOptions from './GFPGANOptions';
|
||||||
|
import OutputOptions from './OutputOptions';
|
||||||
|
import ImageToImageOptions from './ImageToImageOptions';
|
||||||
|
import { ChangeEvent } from 'react';
|
||||||
|
|
||||||
|
const optionsSelector = createSelector(
|
||||||
|
(state: RootState) => state.options,
|
||||||
|
(options: OptionsState) => {
|
||||||
|
return {
|
||||||
|
initialImagePath: options.initialImagePath,
|
||||||
|
shouldUseInitImage: options.shouldUseInitImage,
|
||||||
|
shouldRunESRGAN: options.shouldRunESRGAN,
|
||||||
|
shouldRunGFPGAN: options.shouldRunGFPGAN,
|
||||||
|
};
|
||||||
|
},
|
||||||
|
{
|
||||||
|
memoizeOptions: {
|
||||||
|
resultEqualityCheck: isEqual,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
const systemSelector = createSelector(
|
||||||
|
(state: RootState) => state.system,
|
||||||
|
(system: SystemState) => {
|
||||||
|
return {
|
||||||
|
isGFPGANAvailable: system.isGFPGANAvailable,
|
||||||
|
isESRGANAvailable: system.isESRGANAvailable,
|
||||||
|
openAccordions: system.openAccordions,
|
||||||
|
};
|
||||||
|
},
|
||||||
|
{
|
||||||
|
memoizeOptions: {
|
||||||
|
resultEqualityCheck: isEqual,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Main container for generation and processing parameters.
|
||||||
|
*/
|
||||||
|
const OptionsAccordion = () => {
|
||||||
|
const {
|
||||||
|
shouldRunESRGAN,
|
||||||
|
shouldRunGFPGAN,
|
||||||
|
shouldUseInitImage,
|
||||||
|
initialImagePath,
|
||||||
|
} = useAppSelector(optionsSelector);
|
||||||
|
|
||||||
|
const { isGFPGANAvailable, isESRGANAvailable, openAccordions } =
|
||||||
|
useAppSelector(systemSelector);
|
||||||
|
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Stores accordion state in redux so preferred UI setup is retained.
|
||||||
|
*/
|
||||||
|
const handleChangeAccordionState = (openAccordions: ExpandedIndex) =>
|
||||||
|
dispatch(setOpenAccordions(openAccordions));
|
||||||
|
|
||||||
|
const handleChangeShouldRunESRGAN = (e: ChangeEvent<HTMLInputElement>) =>
|
||||||
|
dispatch(setShouldRunESRGAN(e.target.checked));
|
||||||
|
|
||||||
|
const handleChangeShouldRunGFPGAN = (e: ChangeEvent<HTMLInputElement>) =>
|
||||||
|
dispatch(setShouldRunGFPGAN(e.target.checked));
|
||||||
|
|
||||||
|
const handleChangeShouldUseInitImage = (e: ChangeEvent<HTMLInputElement>) =>
|
||||||
|
dispatch(setShouldUseInitImage(e.target.checked));
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Accordion
|
||||||
|
defaultIndex={openAccordions}
|
||||||
|
allowMultiple
|
||||||
|
reduceMotion
|
||||||
|
onChange={handleChangeAccordionState}
|
||||||
|
>
|
||||||
|
<AccordionItem>
|
||||||
|
<h2>
|
||||||
|
<AccordionButton>
|
||||||
|
<Box flex="1" textAlign="left">
|
||||||
|
Seed & Variation
|
||||||
|
</Box>
|
||||||
|
<AccordionIcon />
|
||||||
|
</AccordionButton>
|
||||||
|
</h2>
|
||||||
|
<AccordionPanel>
|
||||||
|
<SeedVariationOptions />
|
||||||
|
</AccordionPanel>
|
||||||
|
</AccordionItem>
|
||||||
|
<AccordionItem>
|
||||||
|
<h2>
|
||||||
|
<AccordionButton>
|
||||||
|
<Box flex="1" textAlign="left">
|
||||||
|
Sampler
|
||||||
|
</Box>
|
||||||
|
<AccordionIcon />
|
||||||
|
</AccordionButton>
|
||||||
|
</h2>
|
||||||
|
<AccordionPanel>
|
||||||
|
<SamplerOptions />
|
||||||
|
</AccordionPanel>
|
||||||
|
</AccordionItem>
|
||||||
|
<AccordionItem>
|
||||||
|
<h2>
|
||||||
|
<AccordionButton>
|
||||||
|
<Flex
|
||||||
|
justifyContent={'space-between'}
|
||||||
|
alignItems={'center'}
|
||||||
|
width={'100%'}
|
||||||
|
mr={2}
|
||||||
|
>
|
||||||
|
<Text>Upscale (ESRGAN)</Text>
|
||||||
|
<Switch
|
||||||
|
isDisabled={!isESRGANAvailable}
|
||||||
|
isChecked={shouldRunESRGAN}
|
||||||
|
onChange={handleChangeShouldRunESRGAN}
|
||||||
|
/>
|
||||||
|
</Flex>
|
||||||
|
<AccordionIcon />
|
||||||
|
</AccordionButton>
|
||||||
|
</h2>
|
||||||
|
<AccordionPanel>
|
||||||
|
<ESRGANOptions />
|
||||||
|
</AccordionPanel>
|
||||||
|
</AccordionItem>
|
||||||
|
<AccordionItem>
|
||||||
|
<h2>
|
||||||
|
<AccordionButton>
|
||||||
|
<Flex
|
||||||
|
justifyContent={'space-between'}
|
||||||
|
alignItems={'center'}
|
||||||
|
width={'100%'}
|
||||||
|
mr={2}
|
||||||
|
>
|
||||||
|
<Text>Fix Faces (GFPGAN)</Text>
|
||||||
|
<Switch
|
||||||
|
isDisabled={!isGFPGANAvailable}
|
||||||
|
isChecked={shouldRunGFPGAN}
|
||||||
|
onChange={handleChangeShouldRunGFPGAN}
|
||||||
|
/>
|
||||||
|
</Flex>
|
||||||
|
<AccordionIcon />
|
||||||
|
</AccordionButton>
|
||||||
|
</h2>
|
||||||
|
<AccordionPanel>
|
||||||
|
<GFPGANOptions />
|
||||||
|
</AccordionPanel>
|
||||||
|
</AccordionItem>
|
||||||
|
<AccordionItem>
|
||||||
|
<h2>
|
||||||
|
<AccordionButton>
|
||||||
|
<Flex
|
||||||
|
justifyContent={'space-between'}
|
||||||
|
alignItems={'center'}
|
||||||
|
width={'100%'}
|
||||||
|
mr={2}
|
||||||
|
>
|
||||||
|
<Text>Image to Image</Text>
|
||||||
|
<Switch
|
||||||
|
isDisabled={!initialImagePath}
|
||||||
|
isChecked={shouldUseInitImage}
|
||||||
|
onChange={handleChangeShouldUseInitImage}
|
||||||
|
/>
|
||||||
|
</Flex>
|
||||||
|
<AccordionIcon />
|
||||||
|
</AccordionButton>
|
||||||
|
</h2>
|
||||||
|
<AccordionPanel>
|
||||||
|
<ImageToImageOptions />
|
||||||
|
</AccordionPanel>
|
||||||
|
</AccordionItem>
|
||||||
|
<AccordionItem>
|
||||||
|
<h2>
|
||||||
|
<AccordionButton>
|
||||||
|
<Box flex="1" textAlign="left">
|
||||||
|
Output
|
||||||
|
</Box>
|
||||||
|
<AccordionIcon />
|
||||||
|
</AccordionButton>
|
||||||
|
</h2>
|
||||||
|
<AccordionPanel>
|
||||||
|
<OutputOptions />
|
||||||
|
</AccordionPanel>
|
||||||
|
</AccordionItem>
|
||||||
|
</Accordion>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default OptionsAccordion;
|
76
frontend/src/features/options/OutputOptions.tsx
Normal file
76
frontend/src/features/options/OutputOptions.tsx
Normal file
@ -0,0 +1,76 @@
|
|||||||
|
import { Flex } from '@chakra-ui/react';
|
||||||
|
|
||||||
|
import { RootState } from '../../app/store';
|
||||||
|
import { useAppDispatch, useAppSelector } from '../../app/store';
|
||||||
|
|
||||||
|
import { setHeight, setWidth, setSeamless, OptionsState } from '../options/optionsSlice';
|
||||||
|
|
||||||
|
|
||||||
|
import { HEIGHTS, WIDTHS } from '../../app/constants';
|
||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { isEqual } from 'lodash';
|
||||||
|
import { ChangeEvent } from 'react';
|
||||||
|
import SDSelect from '../../common/components/SDSelect';
|
||||||
|
import SDSwitch from '../../common/components/SDSwitch';
|
||||||
|
|
||||||
|
const optionsSelector = createSelector(
|
||||||
|
(state: RootState) => state.options,
|
||||||
|
(options: OptionsState) => {
|
||||||
|
return {
|
||||||
|
height: options.height,
|
||||||
|
width: options.width,
|
||||||
|
seamless: options.seamless,
|
||||||
|
};
|
||||||
|
},
|
||||||
|
{
|
||||||
|
memoizeOptions: {
|
||||||
|
resultEqualityCheck: isEqual,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Image output options. Includes width, height, seamless tiling.
|
||||||
|
*/
|
||||||
|
const OutputOptions = () => {
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const { height, width, seamless } = useAppSelector(optionsSelector);
|
||||||
|
|
||||||
|
const handleChangeWidth = (e: ChangeEvent<HTMLSelectElement>) =>
|
||||||
|
dispatch(setWidth(Number(e.target.value)));
|
||||||
|
|
||||||
|
const handleChangeHeight = (e: ChangeEvent<HTMLSelectElement>) =>
|
||||||
|
dispatch(setHeight(Number(e.target.value)));
|
||||||
|
|
||||||
|
const handleChangeSeamless = (e: ChangeEvent<HTMLInputElement>) =>
|
||||||
|
dispatch(setSeamless(e.target.checked));
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Flex gap={2} direction={'column'}>
|
||||||
|
<Flex gap={2}>
|
||||||
|
<SDSelect
|
||||||
|
label="Width"
|
||||||
|
value={width}
|
||||||
|
flexGrow={1}
|
||||||
|
onChange={handleChangeWidth}
|
||||||
|
validValues={WIDTHS}
|
||||||
|
/>
|
||||||
|
<SDSelect
|
||||||
|
label="Height"
|
||||||
|
value={height}
|
||||||
|
flexGrow={1}
|
||||||
|
onChange={handleChangeHeight}
|
||||||
|
validValues={HEIGHTS}
|
||||||
|
/>
|
||||||
|
</Flex>
|
||||||
|
<SDSwitch
|
||||||
|
label="Seamless tiling"
|
||||||
|
fontSize={'md'}
|
||||||
|
isChecked={seamless}
|
||||||
|
onChange={handleChangeSeamless}
|
||||||
|
/>
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default OutputOptions;
|
68
frontend/src/features/options/ProcessButtons.tsx
Normal file
68
frontend/src/features/options/ProcessButtons.tsx
Normal file
@ -0,0 +1,68 @@
|
|||||||
|
import { Flex } from '@chakra-ui/react';
|
||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { isEqual } from 'lodash';
|
||||||
|
import { useAppDispatch, useAppSelector } from '../../app/store';
|
||||||
|
import { cancelProcessing, generateImage } from '../../app/socketio/actions';
|
||||||
|
import { RootState } from '../../app/store';
|
||||||
|
import SDButton from '../../common/components/SDButton';
|
||||||
|
import useCheckParameters from '../../common/hooks/useCheckParameters';
|
||||||
|
import { SystemState } from '../system/systemSlice';
|
||||||
|
|
||||||
|
const systemSelector = createSelector(
|
||||||
|
(state: RootState) => state.system,
|
||||||
|
(system: SystemState) => {
|
||||||
|
return {
|
||||||
|
isProcessing: system.isProcessing,
|
||||||
|
isConnected: system.isConnected,
|
||||||
|
};
|
||||||
|
},
|
||||||
|
{
|
||||||
|
memoizeOptions: {
|
||||||
|
resultEqualityCheck: isEqual,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Buttons to start and cancel image generation.
|
||||||
|
*/
|
||||||
|
const ProcessButtons = () => {
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const { isProcessing, isConnected } = useAppSelector(systemSelector);
|
||||||
|
const isReady = useCheckParameters();
|
||||||
|
|
||||||
|
const handleClickGenerate = () => dispatch(generateImage());
|
||||||
|
|
||||||
|
const handleClickCancel = () => dispatch(cancelProcessing());
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Flex
|
||||||
|
gap={2}
|
||||||
|
direction={'column'}
|
||||||
|
alignItems={'space-between'}
|
||||||
|
height={'100%'}
|
||||||
|
>
|
||||||
|
<SDButton
|
||||||
|
label="Generate"
|
||||||
|
type="submit"
|
||||||
|
colorScheme="green"
|
||||||
|
flexGrow={1}
|
||||||
|
isDisabled={!isReady}
|
||||||
|
fontSize={'md'}
|
||||||
|
size={'md'}
|
||||||
|
onClick={handleClickGenerate}
|
||||||
|
/>
|
||||||
|
<SDButton
|
||||||
|
label="Cancel"
|
||||||
|
colorScheme="red"
|
||||||
|
flexGrow={1}
|
||||||
|
fontSize={'md'}
|
||||||
|
size={'md'}
|
||||||
|
isDisabled={!isConnected || !isProcessing}
|
||||||
|
onClick={handleClickCancel}
|
||||||
|
/>
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default ProcessButtons;
|
44
frontend/src/features/options/PromptInput.tsx
Normal file
44
frontend/src/features/options/PromptInput.tsx
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
import { Textarea } from '@chakra-ui/react';
|
||||||
|
import {
|
||||||
|
ChangeEvent,
|
||||||
|
KeyboardEvent,
|
||||||
|
} from 'react';
|
||||||
|
import { useAppDispatch, useAppSelector } from '../../app/store';
|
||||||
|
import { generateImage } from '../../app/socketio/actions';
|
||||||
|
import { RootState } from '../../app/store';
|
||||||
|
import { setPrompt } from '../options/optionsSlice';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Prompt input text area.
|
||||||
|
*/
|
||||||
|
const PromptInput = () => {
|
||||||
|
const { prompt } = useAppSelector((state: RootState) => state.options);
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
|
const handleChangePrompt = (e: ChangeEvent<HTMLTextAreaElement>) =>
|
||||||
|
dispatch(setPrompt(e.target.value));
|
||||||
|
|
||||||
|
const handleKeyDown = (e: KeyboardEvent<HTMLTextAreaElement>) => {
|
||||||
|
if (e.key === 'Enter' && e.shiftKey === false) {
|
||||||
|
e.preventDefault();
|
||||||
|
dispatch(generateImage())
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Textarea
|
||||||
|
id="prompt"
|
||||||
|
name="prompt"
|
||||||
|
resize="none"
|
||||||
|
size={'lg'}
|
||||||
|
height={'100%'}
|
||||||
|
isInvalid={!prompt.length}
|
||||||
|
onChange={handleChangePrompt}
|
||||||
|
onKeyDown={handleKeyDown}
|
||||||
|
value={prompt}
|
||||||
|
placeholder="I'm dreaming of..."
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default PromptInput;
|
74
frontend/src/features/options/SamplerOptions.tsx
Normal file
74
frontend/src/features/options/SamplerOptions.tsx
Normal file
@ -0,0 +1,74 @@
|
|||||||
|
import { Flex } from '@chakra-ui/react';
|
||||||
|
|
||||||
|
import { RootState } from '../../app/store';
|
||||||
|
import { useAppDispatch, useAppSelector } from '../../app/store';
|
||||||
|
|
||||||
|
import { setCfgScale, setSampler, setSteps, OptionsState } from '../options/optionsSlice';
|
||||||
|
|
||||||
|
|
||||||
|
import { SAMPLERS } from '../../app/constants';
|
||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { isEqual } from 'lodash';
|
||||||
|
import { ChangeEvent } from 'react';
|
||||||
|
import SDNumberInput from '../../common/components/SDNumberInput';
|
||||||
|
import SDSelect from '../../common/components/SDSelect';
|
||||||
|
|
||||||
|
const optionsSelector = createSelector(
|
||||||
|
(state: RootState) => state.options,
|
||||||
|
(options: OptionsState) => {
|
||||||
|
return {
|
||||||
|
steps: options.steps,
|
||||||
|
cfgScale: options.cfgScale,
|
||||||
|
sampler: options.sampler,
|
||||||
|
};
|
||||||
|
},
|
||||||
|
{
|
||||||
|
memoizeOptions: {
|
||||||
|
resultEqualityCheck: isEqual,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sampler options. Includes steps, CFG scale, sampler.
|
||||||
|
*/
|
||||||
|
const SamplerOptions = () => {
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const { steps, cfgScale, sampler } = useAppSelector(optionsSelector);
|
||||||
|
|
||||||
|
const handleChangeSteps = (v: string | number) =>
|
||||||
|
dispatch(setSteps(Number(v)));
|
||||||
|
|
||||||
|
const handleChangeCfgScale = (v: string | number) =>
|
||||||
|
dispatch(setCfgScale(Number(v)));
|
||||||
|
|
||||||
|
const handleChangeSampler = (e: ChangeEvent<HTMLSelectElement>) =>
|
||||||
|
dispatch(setSampler(e.target.value));
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Flex gap={2} direction={'column'}>
|
||||||
|
<SDNumberInput
|
||||||
|
label="Steps"
|
||||||
|
min={1}
|
||||||
|
step={1}
|
||||||
|
precision={0}
|
||||||
|
onChange={handleChangeSteps}
|
||||||
|
value={steps}
|
||||||
|
/>
|
||||||
|
<SDNumberInput
|
||||||
|
label="CFG scale"
|
||||||
|
step={0.5}
|
||||||
|
onChange={handleChangeCfgScale}
|
||||||
|
value={cfgScale}
|
||||||
|
/>
|
||||||
|
<SDSelect
|
||||||
|
label="Sampler"
|
||||||
|
value={sampler}
|
||||||
|
onChange={handleChangeSampler}
|
||||||
|
validValues={SAMPLERS}
|
||||||
|
/>
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default SamplerOptions;
|
159
frontend/src/features/options/SeedVariationOptions.tsx
Normal file
159
frontend/src/features/options/SeedVariationOptions.tsx
Normal file
@ -0,0 +1,159 @@
|
|||||||
|
import {
|
||||||
|
Flex,
|
||||||
|
Input,
|
||||||
|
HStack,
|
||||||
|
FormControl,
|
||||||
|
FormLabel,
|
||||||
|
Text,
|
||||||
|
Button,
|
||||||
|
} from '@chakra-ui/react';
|
||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { isEqual } from 'lodash';
|
||||||
|
import { ChangeEvent } from 'react';
|
||||||
|
import { NUMPY_RAND_MAX, NUMPY_RAND_MIN } from '../../app/constants';
|
||||||
|
import { useAppDispatch, useAppSelector } from '../../app/store';
|
||||||
|
import { RootState } from '../../app/store';
|
||||||
|
import SDNumberInput from '../../common/components/SDNumberInput';
|
||||||
|
import SDSwitch from '../../common/components/SDSwitch';
|
||||||
|
import randomInt from '../../common/util/randomInt';
|
||||||
|
import { validateSeedWeights } from '../../common/util/seedWeightPairs';
|
||||||
|
import {
|
||||||
|
OptionsState,
|
||||||
|
setIterations,
|
||||||
|
setSeed,
|
||||||
|
setSeedWeights,
|
||||||
|
setShouldGenerateVariations,
|
||||||
|
setShouldRandomizeSeed,
|
||||||
|
setVariationAmount,
|
||||||
|
} from './optionsSlice';
|
||||||
|
|
||||||
|
const optionsSelector = createSelector(
|
||||||
|
(state: RootState) => state.options,
|
||||||
|
(options: OptionsState) => {
|
||||||
|
return {
|
||||||
|
variationAmount: options.variationAmount,
|
||||||
|
seedWeights: options.seedWeights,
|
||||||
|
shouldGenerateVariations: options.shouldGenerateVariations,
|
||||||
|
shouldRandomizeSeed: options.shouldRandomizeSeed,
|
||||||
|
seed: options.seed,
|
||||||
|
iterations: options.iterations,
|
||||||
|
};
|
||||||
|
},
|
||||||
|
{
|
||||||
|
memoizeOptions: {
|
||||||
|
resultEqualityCheck: isEqual,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Seed & variation options. Includes iteration, seed, seed randomization, variation options.
|
||||||
|
*/
|
||||||
|
const SeedVariationOptions = () => {
|
||||||
|
const {
|
||||||
|
shouldGenerateVariations,
|
||||||
|
variationAmount,
|
||||||
|
seedWeights,
|
||||||
|
shouldRandomizeSeed,
|
||||||
|
seed,
|
||||||
|
iterations,
|
||||||
|
} = useAppSelector(optionsSelector);
|
||||||
|
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
|
const handleChangeIterations = (v: string | number) =>
|
||||||
|
dispatch(setIterations(Number(v)));
|
||||||
|
|
||||||
|
const handleChangeShouldRandomizeSeed = (e: ChangeEvent<HTMLInputElement>) =>
|
||||||
|
dispatch(setShouldRandomizeSeed(e.target.checked));
|
||||||
|
|
||||||
|
const handleChangeSeed = (v: string | number) => dispatch(setSeed(Number(v)));
|
||||||
|
|
||||||
|
const handleClickRandomizeSeed = () =>
|
||||||
|
dispatch(setSeed(randomInt(NUMPY_RAND_MIN, NUMPY_RAND_MAX)));
|
||||||
|
|
||||||
|
const handleChangeShouldGenerateVariations = (
|
||||||
|
e: ChangeEvent<HTMLInputElement>
|
||||||
|
) => dispatch(setShouldGenerateVariations(e.target.checked));
|
||||||
|
|
||||||
|
const handleChangevariationAmount = (v: string | number) =>
|
||||||
|
dispatch(setVariationAmount(Number(v)));
|
||||||
|
|
||||||
|
const handleChangeSeedWeights = (e: ChangeEvent<HTMLInputElement>) =>
|
||||||
|
dispatch(setSeedWeights(e.target.value));
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Flex gap={2} direction={'column'}>
|
||||||
|
<SDNumberInput
|
||||||
|
label="Images to generate"
|
||||||
|
step={1}
|
||||||
|
min={1}
|
||||||
|
precision={0}
|
||||||
|
onChange={handleChangeIterations}
|
||||||
|
value={iterations}
|
||||||
|
/>
|
||||||
|
<SDSwitch
|
||||||
|
label="Randomize seed on generation"
|
||||||
|
isChecked={shouldRandomizeSeed}
|
||||||
|
onChange={handleChangeShouldRandomizeSeed}
|
||||||
|
/>
|
||||||
|
<Flex gap={2}>
|
||||||
|
<SDNumberInput
|
||||||
|
label="Seed"
|
||||||
|
step={1}
|
||||||
|
precision={0}
|
||||||
|
flexGrow={1}
|
||||||
|
min={NUMPY_RAND_MIN}
|
||||||
|
max={NUMPY_RAND_MAX}
|
||||||
|
isDisabled={shouldRandomizeSeed}
|
||||||
|
isInvalid={seed < 0 && shouldGenerateVariations}
|
||||||
|
onChange={handleChangeSeed}
|
||||||
|
value={seed}
|
||||||
|
/>
|
||||||
|
<Button
|
||||||
|
size={'sm'}
|
||||||
|
isDisabled={shouldRandomizeSeed}
|
||||||
|
onClick={handleClickRandomizeSeed}
|
||||||
|
>
|
||||||
|
<Text pl={2} pr={2}>
|
||||||
|
Shuffle
|
||||||
|
</Text>
|
||||||
|
</Button>
|
||||||
|
</Flex>
|
||||||
|
<SDSwitch
|
||||||
|
label="Generate variations"
|
||||||
|
isChecked={shouldGenerateVariations}
|
||||||
|
width={'auto'}
|
||||||
|
onChange={handleChangeShouldGenerateVariations}
|
||||||
|
/>
|
||||||
|
<SDNumberInput
|
||||||
|
label="Variation amount"
|
||||||
|
value={variationAmount}
|
||||||
|
step={0.01}
|
||||||
|
min={0}
|
||||||
|
max={1}
|
||||||
|
onChange={handleChangevariationAmount}
|
||||||
|
/>
|
||||||
|
<FormControl
|
||||||
|
isInvalid={
|
||||||
|
shouldGenerateVariations &&
|
||||||
|
!(validateSeedWeights(seedWeights) || seedWeights === '')
|
||||||
|
}
|
||||||
|
flexGrow={1}
|
||||||
|
>
|
||||||
|
<HStack>
|
||||||
|
<FormLabel marginInlineEnd={0} marginBottom={1}>
|
||||||
|
<Text whiteSpace="nowrap">Seed Weights</Text>
|
||||||
|
</FormLabel>
|
||||||
|
<Input
|
||||||
|
size={'sm'}
|
||||||
|
value={seedWeights}
|
||||||
|
onChange={handleChangeSeedWeights}
|
||||||
|
/>
|
||||||
|
</HStack>
|
||||||
|
</FormControl>
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default SeedVariationOptions;
|
@ -1,24 +1,15 @@
|
|||||||
import { createSlice } from '@reduxjs/toolkit';
|
import { createSlice } from '@reduxjs/toolkit';
|
||||||
import type { PayloadAction } from '@reduxjs/toolkit';
|
import type { PayloadAction } from '@reduxjs/toolkit';
|
||||||
import { SDMetadata } from '../gallery/gallerySlice';
|
import * as InvokeAI from '../../app/invokeai';
|
||||||
import randomInt from './util/randomInt';
|
import promptToString from '../../common/util/promptToString';
|
||||||
import { NUMPY_RAND_MAX, NUMPY_RAND_MIN } from '../../app/constants';
|
import { seedWeightsToString } from '../../common/util/seedWeightPairs';
|
||||||
|
|
||||||
const calculateRealSteps = (
|
export type UpscalingLevel = 2 | 4;
|
||||||
steps: number,
|
|
||||||
strength: number,
|
|
||||||
hasInitImage: boolean
|
|
||||||
): number => {
|
|
||||||
return hasInitImage ? Math.floor(strength * steps) : steps;
|
|
||||||
};
|
|
||||||
|
|
||||||
export type UpscalingLevel = 0 | 2 | 3 | 4;
|
export interface OptionsState {
|
||||||
|
|
||||||
export interface SDState {
|
|
||||||
prompt: string;
|
prompt: string;
|
||||||
iterations: number;
|
iterations: number;
|
||||||
steps: number;
|
steps: number;
|
||||||
realSteps: number;
|
|
||||||
cfgScale: number;
|
cfgScale: number;
|
||||||
height: number;
|
height: number;
|
||||||
width: number;
|
width: number;
|
||||||
@ -34,18 +25,17 @@ export interface SDState {
|
|||||||
seamless: boolean;
|
seamless: boolean;
|
||||||
shouldFitToWidthHeight: boolean;
|
shouldFitToWidthHeight: boolean;
|
||||||
shouldGenerateVariations: boolean;
|
shouldGenerateVariations: boolean;
|
||||||
variantAmount: number;
|
variationAmount: number;
|
||||||
seedWeights: string;
|
seedWeights: string;
|
||||||
shouldRunESRGAN: boolean;
|
shouldRunESRGAN: boolean;
|
||||||
shouldRunGFPGAN: boolean;
|
shouldRunGFPGAN: boolean;
|
||||||
shouldRandomizeSeed: boolean;
|
shouldRandomizeSeed: boolean;
|
||||||
}
|
}
|
||||||
|
|
||||||
const initialSDState: SDState = {
|
const initialOptionsState: OptionsState = {
|
||||||
prompt: '',
|
prompt: '',
|
||||||
iterations: 1,
|
iterations: 1,
|
||||||
steps: 50,
|
steps: 50,
|
||||||
realSteps: 50,
|
|
||||||
cfgScale: 7.5,
|
cfgScale: 7.5,
|
||||||
height: 512,
|
height: 512,
|
||||||
width: 512,
|
width: 512,
|
||||||
@ -58,7 +48,7 @@ const initialSDState: SDState = {
|
|||||||
maskPath: '',
|
maskPath: '',
|
||||||
shouldFitToWidthHeight: true,
|
shouldFitToWidthHeight: true,
|
||||||
shouldGenerateVariations: false,
|
shouldGenerateVariations: false,
|
||||||
variantAmount: 0.1,
|
variationAmount: 0.1,
|
||||||
seedWeights: '',
|
seedWeights: '',
|
||||||
shouldRunESRGAN: false,
|
shouldRunESRGAN: false,
|
||||||
upscalingLevel: 4,
|
upscalingLevel: 4,
|
||||||
@ -68,27 +58,25 @@ const initialSDState: SDState = {
|
|||||||
shouldRandomizeSeed: true,
|
shouldRandomizeSeed: true,
|
||||||
};
|
};
|
||||||
|
|
||||||
const initialState: SDState = initialSDState;
|
const initialState: OptionsState = initialOptionsState;
|
||||||
|
|
||||||
export const sdSlice = createSlice({
|
export const optionsSlice = createSlice({
|
||||||
name: 'sd',
|
name: 'options',
|
||||||
initialState,
|
initialState,
|
||||||
reducers: {
|
reducers: {
|
||||||
setPrompt: (state, action: PayloadAction<string>) => {
|
setPrompt: (state, action: PayloadAction<string | InvokeAI.Prompt>) => {
|
||||||
state.prompt = action.payload;
|
const newPrompt = action.payload;
|
||||||
|
if (typeof newPrompt === 'string') {
|
||||||
|
state.prompt = newPrompt;
|
||||||
|
} else {
|
||||||
|
state.prompt = promptToString(newPrompt);
|
||||||
|
}
|
||||||
},
|
},
|
||||||
setIterations: (state, action: PayloadAction<number>) => {
|
setIterations: (state, action: PayloadAction<number>) => {
|
||||||
state.iterations = action.payload;
|
state.iterations = action.payload;
|
||||||
},
|
},
|
||||||
setSteps: (state, action: PayloadAction<number>) => {
|
setSteps: (state, action: PayloadAction<number>) => {
|
||||||
const { img2imgStrength, initialImagePath } = state;
|
state.steps = action.payload;
|
||||||
const steps = action.payload;
|
|
||||||
state.steps = steps;
|
|
||||||
state.realSteps = calculateRealSteps(
|
|
||||||
steps,
|
|
||||||
img2imgStrength,
|
|
||||||
Boolean(initialImagePath)
|
|
||||||
);
|
|
||||||
},
|
},
|
||||||
setCfgScale: (state, action: PayloadAction<number>) => {
|
setCfgScale: (state, action: PayloadAction<number>) => {
|
||||||
state.cfgScale = action.payload;
|
state.cfgScale = action.payload;
|
||||||
@ -107,14 +95,7 @@ export const sdSlice = createSlice({
|
|||||||
state.shouldRandomizeSeed = false;
|
state.shouldRandomizeSeed = false;
|
||||||
},
|
},
|
||||||
setImg2imgStrength: (state, action: PayloadAction<number>) => {
|
setImg2imgStrength: (state, action: PayloadAction<number>) => {
|
||||||
const img2imgStrength = action.payload;
|
state.img2imgStrength = action.payload;
|
||||||
const { steps, initialImagePath } = state;
|
|
||||||
state.img2imgStrength = img2imgStrength;
|
|
||||||
state.realSteps = calculateRealSteps(
|
|
||||||
steps,
|
|
||||||
img2imgStrength,
|
|
||||||
Boolean(initialImagePath)
|
|
||||||
);
|
|
||||||
},
|
},
|
||||||
setGfpganStrength: (state, action: PayloadAction<number>) => {
|
setGfpganStrength: (state, action: PayloadAction<number>) => {
|
||||||
state.gfpganStrength = action.payload;
|
state.gfpganStrength = action.payload;
|
||||||
@ -129,15 +110,9 @@ export const sdSlice = createSlice({
|
|||||||
state.shouldUseInitImage = action.payload;
|
state.shouldUseInitImage = action.payload;
|
||||||
},
|
},
|
||||||
setInitialImagePath: (state, action: PayloadAction<string>) => {
|
setInitialImagePath: (state, action: PayloadAction<string>) => {
|
||||||
const initialImagePath = action.payload;
|
const newInitialImagePath = action.payload;
|
||||||
const { steps, img2imgStrength } = state;
|
state.shouldUseInitImage = newInitialImagePath ? true : false;
|
||||||
state.shouldUseInitImage = initialImagePath ? true : false;
|
state.initialImagePath = newInitialImagePath;
|
||||||
state.initialImagePath = initialImagePath;
|
|
||||||
state.realSteps = calculateRealSteps(
|
|
||||||
steps,
|
|
||||||
img2imgStrength,
|
|
||||||
Boolean(initialImagePath)
|
|
||||||
);
|
|
||||||
},
|
},
|
||||||
setMaskPath: (state, action: PayloadAction<string>) => {
|
setMaskPath: (state, action: PayloadAction<string>) => {
|
||||||
state.maskPath = action.payload;
|
state.maskPath = action.payload;
|
||||||
@ -151,13 +126,11 @@ export const sdSlice = createSlice({
|
|||||||
resetSeed: (state) => {
|
resetSeed: (state) => {
|
||||||
state.seed = -1;
|
state.seed = -1;
|
||||||
},
|
},
|
||||||
randomizeSeed: (state) => {
|
|
||||||
state.seed = randomInt(NUMPY_RAND_MIN, NUMPY_RAND_MAX);
|
|
||||||
},
|
|
||||||
setParameter: (
|
setParameter: (
|
||||||
state,
|
state,
|
||||||
action: PayloadAction<{ key: string; value: string | number | boolean }>
|
action: PayloadAction<{ key: string; value: string | number | boolean }>
|
||||||
) => {
|
) => {
|
||||||
|
// TODO: This probably needs to be refactored.
|
||||||
const { key, value } = action.payload;
|
const { key, value } = action.payload;
|
||||||
const temp = { ...state, [key]: value };
|
const temp = { ...state, [key]: value };
|
||||||
if (key === 'seed') {
|
if (key === 'seed') {
|
||||||
@ -171,70 +144,95 @@ export const sdSlice = createSlice({
|
|||||||
setShouldGenerateVariations: (state, action: PayloadAction<boolean>) => {
|
setShouldGenerateVariations: (state, action: PayloadAction<boolean>) => {
|
||||||
state.shouldGenerateVariations = action.payload;
|
state.shouldGenerateVariations = action.payload;
|
||||||
},
|
},
|
||||||
setVariantAmount: (state, action: PayloadAction<number>) => {
|
setVariationAmount: (state, action: PayloadAction<number>) => {
|
||||||
state.variantAmount = action.payload;
|
state.variationAmount = action.payload;
|
||||||
},
|
},
|
||||||
setSeedWeights: (state, action: PayloadAction<string>) => {
|
setSeedWeights: (state, action: PayloadAction<string>) => {
|
||||||
state.seedWeights = action.payload;
|
state.seedWeights = action.payload;
|
||||||
},
|
},
|
||||||
setAllParameters: (state, action: PayloadAction<SDMetadata>) => {
|
setAllParameters: (state, action: PayloadAction<InvokeAI.Metadata>) => {
|
||||||
const {
|
const {
|
||||||
prompt,
|
type,
|
||||||
steps,
|
postprocessing,
|
||||||
cfgScale,
|
|
||||||
height,
|
|
||||||
width,
|
|
||||||
sampler,
|
sampler,
|
||||||
|
prompt,
|
||||||
seed,
|
seed,
|
||||||
img2imgStrength,
|
variations,
|
||||||
gfpganStrength,
|
steps,
|
||||||
upscalingLevel,
|
cfg_scale,
|
||||||
upscalingStrength,
|
|
||||||
initialImagePath,
|
|
||||||
maskPath,
|
|
||||||
seamless,
|
seamless,
|
||||||
shouldFitToWidthHeight,
|
width,
|
||||||
} = action.payload;
|
height,
|
||||||
|
strength,
|
||||||
|
fit,
|
||||||
|
init_image_path,
|
||||||
|
mask_image_path,
|
||||||
|
} = action.payload.image;
|
||||||
|
|
||||||
// ?? = falsy values ('', 0, etc) are used
|
if (type === 'img2img') {
|
||||||
// || = falsy values not used
|
if (init_image_path) state.initialImagePath = init_image_path;
|
||||||
state.prompt = prompt ?? state.prompt;
|
if (mask_image_path) state.maskPath = mask_image_path;
|
||||||
state.steps = steps || state.steps;
|
if (strength) state.img2imgStrength = strength;
|
||||||
state.cfgScale = cfgScale || state.cfgScale;
|
if (typeof fit === 'boolean') state.shouldFitToWidthHeight = fit;
|
||||||
state.width = width || state.width;
|
state.shouldUseInitImage = true;
|
||||||
state.height = height || state.height;
|
} else {
|
||||||
state.sampler = sampler || state.sampler;
|
state.shouldUseInitImage = false;
|
||||||
state.seed = seed ?? state.seed;
|
}
|
||||||
state.seamless = seamless ?? state.seamless;
|
|
||||||
state.shouldFitToWidthHeight =
|
if (variations && variations.length > 0) {
|
||||||
shouldFitToWidthHeight ?? state.shouldFitToWidthHeight;
|
state.seedWeights = seedWeightsToString(variations);
|
||||||
state.img2imgStrength = img2imgStrength ?? state.img2imgStrength;
|
state.shouldGenerateVariations = true;
|
||||||
state.gfpganStrength = gfpganStrength ?? state.gfpganStrength;
|
} else {
|
||||||
state.upscalingLevel = upscalingLevel ?? state.upscalingLevel;
|
state.shouldGenerateVariations = false;
|
||||||
state.upscalingStrength = upscalingStrength ?? state.upscalingStrength;
|
}
|
||||||
state.initialImagePath = initialImagePath ?? state.initialImagePath;
|
|
||||||
state.maskPath = maskPath ?? state.maskPath;
|
|
||||||
|
|
||||||
// If the image whose parameters we are using has a seed, disable randomizing the seed
|
|
||||||
if (seed) {
|
if (seed) {
|
||||||
|
state.seed = seed;
|
||||||
state.shouldRandomizeSeed = false;
|
state.shouldRandomizeSeed = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
// if we have a gfpgan strength, enable it
|
let postprocessingNotDone = ['gfpgan', 'esrgan'];
|
||||||
state.shouldRunGFPGAN = gfpganStrength ? true : false;
|
if (postprocessing && postprocessing.length > 0) {
|
||||||
|
postprocessing.forEach(
|
||||||
|
(postprocess: InvokeAI.PostProcessedImageMetadata) => {
|
||||||
|
if (postprocess.type === 'gfpgan') {
|
||||||
|
const { strength } = postprocess;
|
||||||
|
if (strength) state.gfpganStrength = strength;
|
||||||
|
state.shouldRunGFPGAN = true;
|
||||||
|
postprocessingNotDone = postprocessingNotDone.filter(
|
||||||
|
(p) => p !== 'gfpgan'
|
||||||
|
);
|
||||||
|
}
|
||||||
|
if (postprocess.type === 'esrgan') {
|
||||||
|
const { scale, strength } = postprocess;
|
||||||
|
if (scale) state.upscalingLevel = scale;
|
||||||
|
if (strength) state.upscalingStrength = strength;
|
||||||
|
state.shouldRunESRGAN = true;
|
||||||
|
postprocessingNotDone = postprocessingNotDone.filter(
|
||||||
|
(p) => p !== 'esrgan'
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
// if we have a esrgan strength, enable it
|
postprocessingNotDone.forEach((p) => {
|
||||||
state.shouldRunESRGAN = upscalingLevel ? true : false;
|
if (p === 'esrgan') state.shouldRunESRGAN = false;
|
||||||
|
if (p === 'gfpgan') state.shouldRunGFPGAN = false;
|
||||||
|
});
|
||||||
|
|
||||||
// if we want to recreate an image exactly, we disable variations
|
if (prompt) state.prompt = promptToString(prompt);
|
||||||
state.shouldGenerateVariations = false;
|
if (sampler) state.sampler = sampler;
|
||||||
|
if (steps) state.steps = steps;
|
||||||
state.shouldUseInitImage = initialImagePath ? true : false;
|
if (cfg_scale) state.cfgScale = cfg_scale;
|
||||||
|
if (typeof seamless === 'boolean') state.seamless = seamless;
|
||||||
|
if (width) state.width = width;
|
||||||
|
if (height) state.height = height;
|
||||||
},
|
},
|
||||||
resetSDState: (state) => {
|
resetOptionsState: (state) => {
|
||||||
return {
|
return {
|
||||||
...state,
|
...state,
|
||||||
...initialSDState,
|
...initialOptionsState,
|
||||||
};
|
};
|
||||||
},
|
},
|
||||||
setShouldRunGFPGAN: (state, action: PayloadAction<boolean>) => {
|
setShouldRunGFPGAN: (state, action: PayloadAction<boolean>) => {
|
||||||
@ -267,17 +265,16 @@ export const {
|
|||||||
setInitialImagePath,
|
setInitialImagePath,
|
||||||
setMaskPath,
|
setMaskPath,
|
||||||
resetSeed,
|
resetSeed,
|
||||||
randomizeSeed,
|
resetOptionsState,
|
||||||
resetSDState,
|
|
||||||
setShouldFitToWidthHeight,
|
setShouldFitToWidthHeight,
|
||||||
setParameter,
|
setParameter,
|
||||||
setShouldGenerateVariations,
|
setShouldGenerateVariations,
|
||||||
setSeedWeights,
|
setSeedWeights,
|
||||||
setVariantAmount,
|
setVariationAmount,
|
||||||
setAllParameters,
|
setAllParameters,
|
||||||
setShouldRunGFPGAN,
|
setShouldRunGFPGAN,
|
||||||
setShouldRunESRGAN,
|
setShouldRunESRGAN,
|
||||||
setShouldRandomizeSeed,
|
setShouldRandomizeSeed,
|
||||||
} = sdSlice.actions;
|
} = optionsSlice.actions;
|
||||||
|
|
||||||
export default sdSlice.reducer;
|
export default optionsSlice.reducer;
|
@ -1,84 +0,0 @@
|
|||||||
import { Flex } from '@chakra-ui/react';
|
|
||||||
|
|
||||||
import { RootState } from '../../app/store';
|
|
||||||
import { useAppDispatch, useAppSelector } from '../../app/hooks';
|
|
||||||
|
|
||||||
import {
|
|
||||||
setUpscalingLevel,
|
|
||||||
setUpscalingStrength,
|
|
||||||
UpscalingLevel,
|
|
||||||
SDState,
|
|
||||||
} from '../sd/sdSlice';
|
|
||||||
|
|
||||||
import SDNumberInput from '../../components/SDNumberInput';
|
|
||||||
import SDSelect from '../../components/SDSelect';
|
|
||||||
|
|
||||||
import { UPSCALING_LEVELS } from '../../app/constants';
|
|
||||||
import { createSelector } from '@reduxjs/toolkit';
|
|
||||||
import { isEqual } from 'lodash';
|
|
||||||
import { SystemState } from '../system/systemSlice';
|
|
||||||
|
|
||||||
const sdSelector = createSelector(
|
|
||||||
(state: RootState) => state.sd,
|
|
||||||
(sd: SDState) => {
|
|
||||||
return {
|
|
||||||
upscalingLevel: sd.upscalingLevel,
|
|
||||||
upscalingStrength: sd.upscalingStrength,
|
|
||||||
};
|
|
||||||
},
|
|
||||||
{
|
|
||||||
memoizeOptions: {
|
|
||||||
resultEqualityCheck: isEqual,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
);
|
|
||||||
|
|
||||||
const systemSelector = createSelector(
|
|
||||||
(state: RootState) => state.system,
|
|
||||||
(system: SystemState) => {
|
|
||||||
return {
|
|
||||||
isESRGANAvailable: system.isESRGANAvailable,
|
|
||||||
};
|
|
||||||
},
|
|
||||||
{
|
|
||||||
memoizeOptions: {
|
|
||||||
resultEqualityCheck: isEqual,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
);
|
|
||||||
const ESRGANOptions = () => {
|
|
||||||
const { upscalingLevel, upscalingStrength } = useAppSelector(sdSelector);
|
|
||||||
|
|
||||||
const { isESRGANAvailable } = useAppSelector(systemSelector);
|
|
||||||
|
|
||||||
const dispatch = useAppDispatch();
|
|
||||||
|
|
||||||
return (
|
|
||||||
<Flex direction={'column'} gap={2}>
|
|
||||||
<SDSelect
|
|
||||||
isDisabled={!isESRGANAvailable}
|
|
||||||
label='Scale'
|
|
||||||
value={upscalingLevel}
|
|
||||||
onChange={(e) =>
|
|
||||||
dispatch(
|
|
||||||
setUpscalingLevel(
|
|
||||||
Number(e.target.value) as UpscalingLevel
|
|
||||||
)
|
|
||||||
)
|
|
||||||
}
|
|
||||||
validValues={UPSCALING_LEVELS}
|
|
||||||
/>
|
|
||||||
<SDNumberInput
|
|
||||||
isDisabled={!isESRGANAvailable}
|
|
||||||
label='Strength'
|
|
||||||
step={0.05}
|
|
||||||
min={0}
|
|
||||||
max={1}
|
|
||||||
onChange={(v) => dispatch(setUpscalingStrength(Number(v)))}
|
|
||||||
value={upscalingStrength}
|
|
||||||
/>
|
|
||||||
</Flex>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
export default ESRGANOptions;
|
|
@ -1,63 +0,0 @@
|
|||||||
import { Flex } from '@chakra-ui/react';
|
|
||||||
|
|
||||||
import { RootState } from '../../app/store';
|
|
||||||
import { useAppDispatch, useAppSelector } from '../../app/hooks';
|
|
||||||
|
|
||||||
import { SDState, setGfpganStrength } from '../sd/sdSlice';
|
|
||||||
|
|
||||||
import SDNumberInput from '../../components/SDNumberInput';
|
|
||||||
|
|
||||||
import { createSelector } from '@reduxjs/toolkit';
|
|
||||||
import { isEqual } from 'lodash';
|
|
||||||
import { SystemState } from '../system/systemSlice';
|
|
||||||
|
|
||||||
const sdSelector = createSelector(
|
|
||||||
(state: RootState) => state.sd,
|
|
||||||
(sd: SDState) => {
|
|
||||||
return {
|
|
||||||
gfpganStrength: sd.gfpganStrength,
|
|
||||||
};
|
|
||||||
},
|
|
||||||
{
|
|
||||||
memoizeOptions: {
|
|
||||||
resultEqualityCheck: isEqual,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
);
|
|
||||||
|
|
||||||
const systemSelector = createSelector(
|
|
||||||
(state: RootState) => state.system,
|
|
||||||
(system: SystemState) => {
|
|
||||||
return {
|
|
||||||
isGFPGANAvailable: system.isGFPGANAvailable,
|
|
||||||
};
|
|
||||||
},
|
|
||||||
{
|
|
||||||
memoizeOptions: {
|
|
||||||
resultEqualityCheck: isEqual,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
);
|
|
||||||
const GFPGANOptions = () => {
|
|
||||||
const { gfpganStrength } = useAppSelector(sdSelector);
|
|
||||||
|
|
||||||
const { isGFPGANAvailable } = useAppSelector(systemSelector);
|
|
||||||
|
|
||||||
const dispatch = useAppDispatch();
|
|
||||||
|
|
||||||
return (
|
|
||||||
<Flex direction={'column'} gap={2}>
|
|
||||||
<SDNumberInput
|
|
||||||
isDisabled={!isGFPGANAvailable}
|
|
||||||
label='Strength'
|
|
||||||
step={0.05}
|
|
||||||
min={0}
|
|
||||||
max={1}
|
|
||||||
onChange={(v) => dispatch(setGfpganStrength(Number(v)))}
|
|
||||||
value={gfpganStrength}
|
|
||||||
/>
|
|
||||||
</Flex>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
export default GFPGANOptions;
|
|
@ -1,54 +0,0 @@
|
|||||||
import { Flex } from '@chakra-ui/react';
|
|
||||||
import { createSelector } from '@reduxjs/toolkit';
|
|
||||||
import { useAppDispatch, useAppSelector } from '../../app/hooks';
|
|
||||||
import { RootState } from '../../app/store';
|
|
||||||
import SDNumberInput from '../../components/SDNumberInput';
|
|
||||||
import SDSwitch from '../../components/SDSwitch';
|
|
||||||
import InitImage from './InitImage';
|
|
||||||
import {
|
|
||||||
SDState,
|
|
||||||
setImg2imgStrength,
|
|
||||||
setShouldFitToWidthHeight,
|
|
||||||
} from './sdSlice';
|
|
||||||
|
|
||||||
const sdSelector = createSelector(
|
|
||||||
(state: RootState) => state.sd,
|
|
||||||
(sd: SDState) => {
|
|
||||||
return {
|
|
||||||
initialImagePath: sd.initialImagePath,
|
|
||||||
img2imgStrength: sd.img2imgStrength,
|
|
||||||
shouldFitToWidthHeight: sd.shouldFitToWidthHeight,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
);
|
|
||||||
|
|
||||||
const ImageToImageOptions = () => {
|
|
||||||
const { initialImagePath, img2imgStrength, shouldFitToWidthHeight } =
|
|
||||||
useAppSelector(sdSelector);
|
|
||||||
|
|
||||||
const dispatch = useAppDispatch();
|
|
||||||
return (
|
|
||||||
<Flex direction={'column'} gap={2}>
|
|
||||||
<SDNumberInput
|
|
||||||
isDisabled={!initialImagePath}
|
|
||||||
label='Strength'
|
|
||||||
step={0.01}
|
|
||||||
min={0}
|
|
||||||
max={1}
|
|
||||||
onChange={(v) => dispatch(setImg2imgStrength(Number(v)))}
|
|
||||||
value={img2imgStrength}
|
|
||||||
/>
|
|
||||||
<SDSwitch
|
|
||||||
isDisabled={!initialImagePath}
|
|
||||||
label='Fit initial image to output size'
|
|
||||||
isChecked={shouldFitToWidthHeight}
|
|
||||||
onChange={(e) =>
|
|
||||||
dispatch(setShouldFitToWidthHeight(e.target.checked))
|
|
||||||
}
|
|
||||||
/>
|
|
||||||
<InitImage />
|
|
||||||
</Flex>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
export default ImageToImageOptions;
|
|
@ -1,155 +0,0 @@
|
|||||||
import {
|
|
||||||
Button,
|
|
||||||
Flex,
|
|
||||||
IconButton,
|
|
||||||
Image,
|
|
||||||
useToast,
|
|
||||||
} from '@chakra-ui/react';
|
|
||||||
import { SyntheticEvent, useCallback, useState } from 'react';
|
|
||||||
import { FileRejection, useDropzone } from 'react-dropzone';
|
|
||||||
import { FaTrash } from 'react-icons/fa';
|
|
||||||
import { useAppDispatch, useAppSelector } from '../../app/hooks';
|
|
||||||
import { RootState } from '../../app/store';
|
|
||||||
import {
|
|
||||||
SDState,
|
|
||||||
setInitialImagePath,
|
|
||||||
setMaskPath,
|
|
||||||
} from '../../features/sd/sdSlice';
|
|
||||||
import MaskUploader from './MaskUploader';
|
|
||||||
import './InitImage.css';
|
|
||||||
import { uploadInitialImage } from '../../app/socketio';
|
|
||||||
import { createSelector } from '@reduxjs/toolkit';
|
|
||||||
import { isEqual } from 'lodash';
|
|
||||||
|
|
||||||
const sdSelector = createSelector(
|
|
||||||
(state: RootState) => state.sd,
|
|
||||||
(sd: SDState) => {
|
|
||||||
return {
|
|
||||||
initialImagePath: sd.initialImagePath,
|
|
||||||
maskPath: sd.maskPath,
|
|
||||||
};
|
|
||||||
},
|
|
||||||
{ memoizeOptions: { resultEqualityCheck: isEqual } }
|
|
||||||
);
|
|
||||||
|
|
||||||
const InitImage = () => {
|
|
||||||
const toast = useToast();
|
|
||||||
const dispatch = useAppDispatch();
|
|
||||||
const { initialImagePath, maskPath } = useAppSelector(sdSelector);
|
|
||||||
|
|
||||||
const onDrop = useCallback(
|
|
||||||
(acceptedFiles: Array<File>, fileRejections: Array<FileRejection>) => {
|
|
||||||
fileRejections.forEach((rejection: FileRejection) => {
|
|
||||||
const msg = rejection.errors.reduce(
|
|
||||||
(acc: string, cur: { message: string }) => acc + '\n' + cur.message,
|
|
||||||
''
|
|
||||||
);
|
|
||||||
|
|
||||||
toast({
|
|
||||||
title: 'Upload failed',
|
|
||||||
description: msg,
|
|
||||||
status: 'error',
|
|
||||||
isClosable: true,
|
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
acceptedFiles.forEach((file: File) => {
|
|
||||||
dispatch(uploadInitialImage(file));
|
|
||||||
});
|
|
||||||
},
|
|
||||||
[dispatch, toast]
|
|
||||||
);
|
|
||||||
|
|
||||||
const { getRootProps, getInputProps, open } = useDropzone({
|
|
||||||
onDrop,
|
|
||||||
accept: {
|
|
||||||
'image/jpeg': ['.jpg', '.jpeg', '.png'],
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
const [shouldShowMask, setShouldShowMask] = useState<boolean>(false);
|
|
||||||
const handleClickUploadIcon = (e: SyntheticEvent) => {
|
|
||||||
e.stopPropagation();
|
|
||||||
open();
|
|
||||||
};
|
|
||||||
const handleClickResetInitialImageAndMask = (e: SyntheticEvent) => {
|
|
||||||
e.stopPropagation();
|
|
||||||
dispatch(setInitialImagePath(''));
|
|
||||||
dispatch(setMaskPath(''));
|
|
||||||
};
|
|
||||||
|
|
||||||
const handleMouseOverInitialImageUploadButton = () =>
|
|
||||||
setShouldShowMask(false);
|
|
||||||
const handleMouseOutInitialImageUploadButton = () => setShouldShowMask(true);
|
|
||||||
|
|
||||||
const handleMouseOverMaskUploadButton = () => setShouldShowMask(true);
|
|
||||||
const handleMouseOutMaskUploadButton = () => setShouldShowMask(true);
|
|
||||||
|
|
||||||
return (
|
|
||||||
<Flex
|
|
||||||
{...getRootProps({
|
|
||||||
onClick: initialImagePath ? (e) => e.stopPropagation() : undefined,
|
|
||||||
})}
|
|
||||||
direction={'column'}
|
|
||||||
alignItems={'center'}
|
|
||||||
gap={2}
|
|
||||||
>
|
|
||||||
<input {...getInputProps({ multiple: false })} />
|
|
||||||
<Flex gap={2} justifyContent={'space-between'} width={'100%'}>
|
|
||||||
<Button
|
|
||||||
size={'sm'}
|
|
||||||
fontSize={'md'}
|
|
||||||
fontWeight={'normal'}
|
|
||||||
onClick={handleClickUploadIcon}
|
|
||||||
onMouseOver={handleMouseOverInitialImageUploadButton}
|
|
||||||
onMouseOut={handleMouseOutInitialImageUploadButton}
|
|
||||||
>
|
|
||||||
Upload Image
|
|
||||||
</Button>
|
|
||||||
|
|
||||||
<MaskUploader>
|
|
||||||
<Button
|
|
||||||
size={'sm'}
|
|
||||||
fontSize={'md'}
|
|
||||||
fontWeight={'normal'}
|
|
||||||
onClick={handleClickUploadIcon}
|
|
||||||
onMouseOver={handleMouseOverMaskUploadButton}
|
|
||||||
onMouseOut={handleMouseOutMaskUploadButton}
|
|
||||||
>
|
|
||||||
Upload Mask
|
|
||||||
</Button>
|
|
||||||
</MaskUploader>
|
|
||||||
<IconButton
|
|
||||||
size={'sm'}
|
|
||||||
aria-label={'Reset initial image and mask'}
|
|
||||||
onClick={handleClickResetInitialImageAndMask}
|
|
||||||
icon={<FaTrash />}
|
|
||||||
/>
|
|
||||||
</Flex>
|
|
||||||
{initialImagePath && (
|
|
||||||
<Flex position={'relative'} width={'100%'}>
|
|
||||||
<Image
|
|
||||||
fit={'contain'}
|
|
||||||
src={initialImagePath}
|
|
||||||
rounded={'md'}
|
|
||||||
className={'checkerboard'}
|
|
||||||
/>
|
|
||||||
{shouldShowMask && maskPath && (
|
|
||||||
<Image
|
|
||||||
position={'absolute'}
|
|
||||||
top={0}
|
|
||||||
left={0}
|
|
||||||
fit={'contain'}
|
|
||||||
src={maskPath}
|
|
||||||
rounded={'md'}
|
|
||||||
zIndex={1}
|
|
||||||
className={'checkerboard'}
|
|
||||||
/>
|
|
||||||
)}
|
|
||||||
</Flex>
|
|
||||||
)}
|
|
||||||
</Flex>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
export default InitImage;
|
|
@ -1,61 +0,0 @@
|
|||||||
import { useToast } from '@chakra-ui/react';
|
|
||||||
import { cloneElement, ReactElement, SyntheticEvent, useCallback } from 'react';
|
|
||||||
import { FileRejection, useDropzone } from 'react-dropzone';
|
|
||||||
import { useAppDispatch } from '../../app/hooks';
|
|
||||||
import { uploadMaskImage } from '../../app/socketio';
|
|
||||||
|
|
||||||
type Props = {
|
|
||||||
children: ReactElement;
|
|
||||||
};
|
|
||||||
|
|
||||||
const MaskUploader = ({ children }: Props) => {
|
|
||||||
const dispatch = useAppDispatch();
|
|
||||||
const toast = useToast();
|
|
||||||
|
|
||||||
const onDrop = useCallback(
|
|
||||||
(acceptedFiles: Array<File>, fileRejections: Array<FileRejection>) => {
|
|
||||||
fileRejections.forEach((rejection: FileRejection) => {
|
|
||||||
const msg = rejection.errors.reduce(
|
|
||||||
(acc: string, cur: { message: string }) =>
|
|
||||||
acc + '\n' + cur.message,
|
|
||||||
''
|
|
||||||
);
|
|
||||||
|
|
||||||
toast({
|
|
||||||
title: 'Upload failed',
|
|
||||||
description: msg,
|
|
||||||
status: 'error',
|
|
||||||
isClosable: true,
|
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
acceptedFiles.forEach((file: File) => {
|
|
||||||
dispatch(uploadMaskImage(file));
|
|
||||||
});
|
|
||||||
},
|
|
||||||
[dispatch, toast]
|
|
||||||
);
|
|
||||||
|
|
||||||
const { getRootProps, getInputProps, open } = useDropzone({
|
|
||||||
onDrop,
|
|
||||||
accept: {
|
|
||||||
'image/jpeg': ['.jpg', '.jpeg', '.png'],
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
const handleClickUploadIcon = (e: SyntheticEvent) => {
|
|
||||||
e.stopPropagation();
|
|
||||||
open();
|
|
||||||
};
|
|
||||||
|
|
||||||
return (
|
|
||||||
<div {...getRootProps()}>
|
|
||||||
<input {...getInputProps({ multiple: false })} />
|
|
||||||
{cloneElement(children, {
|
|
||||||
onClick: handleClickUploadIcon,
|
|
||||||
})}
|
|
||||||
</div>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
export default MaskUploader;
|
|
@ -1,211 +0,0 @@
|
|||||||
import {
|
|
||||||
Flex,
|
|
||||||
Box,
|
|
||||||
Text,
|
|
||||||
Accordion,
|
|
||||||
AccordionItem,
|
|
||||||
AccordionButton,
|
|
||||||
AccordionIcon,
|
|
||||||
AccordionPanel,
|
|
||||||
Switch,
|
|
||||||
} from '@chakra-ui/react';
|
|
||||||
|
|
||||||
import { RootState } from '../../app/store';
|
|
||||||
import { useAppDispatch, useAppSelector } from '../../app/hooks';
|
|
||||||
|
|
||||||
import {
|
|
||||||
setShouldRunGFPGAN,
|
|
||||||
setShouldRunESRGAN,
|
|
||||||
SDState,
|
|
||||||
setShouldUseInitImage,
|
|
||||||
} from '../sd/sdSlice';
|
|
||||||
import { createSelector } from '@reduxjs/toolkit';
|
|
||||||
import { isEqual } from 'lodash';
|
|
||||||
import { setOpenAccordions, SystemState } from '../system/systemSlice';
|
|
||||||
import SeedVariationOptions from './SeedVariationOptions';
|
|
||||||
import SamplerOptions from './SamplerOptions';
|
|
||||||
import ESRGANOptions from './ESRGANOptions';
|
|
||||||
import GFPGANOptions from './GFPGANOptions';
|
|
||||||
import OutputOptions from './OutputOptions';
|
|
||||||
import ImageToImageOptions from './ImageToImageOptions';
|
|
||||||
|
|
||||||
const sdSelector = createSelector(
|
|
||||||
(state: RootState) => state.sd,
|
|
||||||
(sd: SDState) => {
|
|
||||||
return {
|
|
||||||
initialImagePath: sd.initialImagePath,
|
|
||||||
shouldUseInitImage: sd.shouldUseInitImage,
|
|
||||||
shouldRunESRGAN: sd.shouldRunESRGAN,
|
|
||||||
shouldRunGFPGAN: sd.shouldRunGFPGAN,
|
|
||||||
};
|
|
||||||
},
|
|
||||||
{
|
|
||||||
memoizeOptions: {
|
|
||||||
resultEqualityCheck: isEqual,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
);
|
|
||||||
|
|
||||||
const systemSelector = createSelector(
|
|
||||||
(state: RootState) => state.system,
|
|
||||||
(system: SystemState) => {
|
|
||||||
return {
|
|
||||||
isGFPGANAvailable: system.isGFPGANAvailable,
|
|
||||||
isESRGANAvailable: system.isESRGANAvailable,
|
|
||||||
openAccordions: system.openAccordions,
|
|
||||||
};
|
|
||||||
},
|
|
||||||
{
|
|
||||||
memoizeOptions: {
|
|
||||||
resultEqualityCheck: isEqual,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
);
|
|
||||||
|
|
||||||
const OptionsAccordion = () => {
|
|
||||||
const {
|
|
||||||
shouldRunESRGAN,
|
|
||||||
shouldRunGFPGAN,
|
|
||||||
shouldUseInitImage,
|
|
||||||
initialImagePath,
|
|
||||||
} = useAppSelector(sdSelector);
|
|
||||||
|
|
||||||
const { isGFPGANAvailable, isESRGANAvailable, openAccordions } =
|
|
||||||
useAppSelector(systemSelector);
|
|
||||||
|
|
||||||
const dispatch = useAppDispatch();
|
|
||||||
|
|
||||||
return (
|
|
||||||
<Accordion
|
|
||||||
defaultIndex={openAccordions}
|
|
||||||
allowMultiple
|
|
||||||
reduceMotion
|
|
||||||
onChange={(openAccordions) =>
|
|
||||||
dispatch(setOpenAccordions(openAccordions))
|
|
||||||
}
|
|
||||||
>
|
|
||||||
<AccordionItem>
|
|
||||||
<h2>
|
|
||||||
<AccordionButton>
|
|
||||||
<Box flex='1' textAlign='left'>
|
|
||||||
Seed & Variation
|
|
||||||
</Box>
|
|
||||||
<AccordionIcon />
|
|
||||||
</AccordionButton>
|
|
||||||
</h2>
|
|
||||||
<AccordionPanel>
|
|
||||||
<SeedVariationOptions />
|
|
||||||
</AccordionPanel>
|
|
||||||
</AccordionItem>
|
|
||||||
<AccordionItem>
|
|
||||||
<h2>
|
|
||||||
<AccordionButton>
|
|
||||||
<Box flex='1' textAlign='left'>
|
|
||||||
Sampler
|
|
||||||
</Box>
|
|
||||||
<AccordionIcon />
|
|
||||||
</AccordionButton>
|
|
||||||
</h2>
|
|
||||||
<AccordionPanel>
|
|
||||||
<SamplerOptions />
|
|
||||||
</AccordionPanel>
|
|
||||||
</AccordionItem>
|
|
||||||
<AccordionItem>
|
|
||||||
<h2>
|
|
||||||
<AccordionButton>
|
|
||||||
<Flex
|
|
||||||
justifyContent={'space-between'}
|
|
||||||
alignItems={'center'}
|
|
||||||
width={'100%'}
|
|
||||||
mr={2}
|
|
||||||
>
|
|
||||||
<Text>Upscale (ESRGAN)</Text>
|
|
||||||
<Switch
|
|
||||||
isDisabled={!isESRGANAvailable}
|
|
||||||
isChecked={shouldRunESRGAN}
|
|
||||||
onChange={(e) =>
|
|
||||||
dispatch(
|
|
||||||
setShouldRunESRGAN(e.target.checked)
|
|
||||||
)
|
|
||||||
}
|
|
||||||
/>
|
|
||||||
</Flex>
|
|
||||||
<AccordionIcon />
|
|
||||||
</AccordionButton>
|
|
||||||
</h2>
|
|
||||||
<AccordionPanel>
|
|
||||||
<ESRGANOptions />
|
|
||||||
</AccordionPanel>
|
|
||||||
</AccordionItem>
|
|
||||||
<AccordionItem>
|
|
||||||
<h2>
|
|
||||||
<AccordionButton>
|
|
||||||
<Flex
|
|
||||||
justifyContent={'space-between'}
|
|
||||||
alignItems={'center'}
|
|
||||||
width={'100%'}
|
|
||||||
mr={2}
|
|
||||||
>
|
|
||||||
<Text>Fix Faces (GFPGAN)</Text>
|
|
||||||
<Switch
|
|
||||||
isDisabled={!isGFPGANAvailable}
|
|
||||||
isChecked={shouldRunGFPGAN}
|
|
||||||
onChange={(e) =>
|
|
||||||
dispatch(
|
|
||||||
setShouldRunGFPGAN(e.target.checked)
|
|
||||||
)
|
|
||||||
}
|
|
||||||
/>
|
|
||||||
</Flex>
|
|
||||||
<AccordionIcon />
|
|
||||||
</AccordionButton>
|
|
||||||
</h2>
|
|
||||||
<AccordionPanel>
|
|
||||||
<GFPGANOptions />
|
|
||||||
</AccordionPanel>
|
|
||||||
</AccordionItem>
|
|
||||||
<AccordionItem>
|
|
||||||
<h2>
|
|
||||||
<AccordionButton>
|
|
||||||
<Flex
|
|
||||||
justifyContent={'space-between'}
|
|
||||||
alignItems={'center'}
|
|
||||||
width={'100%'}
|
|
||||||
mr={2}
|
|
||||||
>
|
|
||||||
<Text>Image to Image</Text>
|
|
||||||
<Switch
|
|
||||||
isDisabled={!initialImagePath}
|
|
||||||
isChecked={shouldUseInitImage}
|
|
||||||
onChange={(e) =>
|
|
||||||
dispatch(
|
|
||||||
setShouldUseInitImage(e.target.checked)
|
|
||||||
)
|
|
||||||
}
|
|
||||||
/>
|
|
||||||
</Flex>
|
|
||||||
<AccordionIcon />
|
|
||||||
</AccordionButton>
|
|
||||||
</h2>
|
|
||||||
<AccordionPanel>
|
|
||||||
<ImageToImageOptions />
|
|
||||||
</AccordionPanel>
|
|
||||||
</AccordionItem>
|
|
||||||
<AccordionItem>
|
|
||||||
<h2>
|
|
||||||
<AccordionButton>
|
|
||||||
<Box flex='1' textAlign='left'>
|
|
||||||
Output
|
|
||||||
</Box>
|
|
||||||
<AccordionIcon />
|
|
||||||
</AccordionButton>
|
|
||||||
</h2>
|
|
||||||
<AccordionPanel>
|
|
||||||
<OutputOptions />
|
|
||||||
</AccordionPanel>
|
|
||||||
</AccordionItem>
|
|
||||||
</Accordion>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
export default OptionsAccordion;
|
|
@ -1,66 +0,0 @@
|
|||||||
import { Flex } from '@chakra-ui/react';
|
|
||||||
|
|
||||||
import { RootState } from '../../app/store';
|
|
||||||
import { useAppDispatch, useAppSelector } from '../../app/hooks';
|
|
||||||
|
|
||||||
import { setHeight, setWidth, setSeamless, SDState } from '../sd/sdSlice';
|
|
||||||
|
|
||||||
import SDSelect from '../../components/SDSelect';
|
|
||||||
|
|
||||||
import { HEIGHTS, WIDTHS } from '../../app/constants';
|
|
||||||
import SDSwitch from '../../components/SDSwitch';
|
|
||||||
import { createSelector } from '@reduxjs/toolkit';
|
|
||||||
import { isEqual } from 'lodash';
|
|
||||||
|
|
||||||
const sdSelector = createSelector(
|
|
||||||
(state: RootState) => state.sd,
|
|
||||||
(sd: SDState) => {
|
|
||||||
return {
|
|
||||||
height: sd.height,
|
|
||||||
width: sd.width,
|
|
||||||
seamless: sd.seamless,
|
|
||||||
};
|
|
||||||
},
|
|
||||||
{
|
|
||||||
memoizeOptions: {
|
|
||||||
resultEqualityCheck: isEqual,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
);
|
|
||||||
|
|
||||||
const OutputOptions = () => {
|
|
||||||
const { height, width, seamless } = useAppSelector(sdSelector);
|
|
||||||
|
|
||||||
const dispatch = useAppDispatch();
|
|
||||||
|
|
||||||
return (
|
|
||||||
<Flex gap={2} direction={'column'}>
|
|
||||||
<Flex gap={2}>
|
|
||||||
<SDSelect
|
|
||||||
label='Width'
|
|
||||||
value={width}
|
|
||||||
flexGrow={1}
|
|
||||||
onChange={(e) => dispatch(setWidth(Number(e.target.value)))}
|
|
||||||
validValues={WIDTHS}
|
|
||||||
/>
|
|
||||||
<SDSelect
|
|
||||||
label='Height'
|
|
||||||
value={height}
|
|
||||||
flexGrow={1}
|
|
||||||
onChange={(e) =>
|
|
||||||
dispatch(setHeight(Number(e.target.value)))
|
|
||||||
}
|
|
||||||
validValues={HEIGHTS}
|
|
||||||
/>
|
|
||||||
</Flex>
|
|
||||||
<SDSwitch
|
|
||||||
label='Seamless tiling'
|
|
||||||
fontSize={'md'}
|
|
||||||
isChecked={seamless}
|
|
||||||
onChange={(e) => dispatch(setSeamless(e.target.checked))}
|
|
||||||
/>
|
|
||||||
</Flex>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
export default OutputOptions;
|
|
@ -1,58 +0,0 @@
|
|||||||
import { Flex } from '@chakra-ui/react';
|
|
||||||
import { createSelector } from '@reduxjs/toolkit';
|
|
||||||
import { isEqual } from 'lodash';
|
|
||||||
import { useAppDispatch, useAppSelector } from '../../app/hooks';
|
|
||||||
import { cancelProcessing, generateImage } from '../../app/socketio';
|
|
||||||
import { RootState } from '../../app/store';
|
|
||||||
import SDButton from '../../components/SDButton';
|
|
||||||
import { SystemState } from '../system/systemSlice';
|
|
||||||
import useCheckParameters from '../system/useCheckParameters';
|
|
||||||
|
|
||||||
const systemSelector = createSelector(
|
|
||||||
(state: RootState) => state.system,
|
|
||||||
(system: SystemState) => {
|
|
||||||
return {
|
|
||||||
isProcessing: system.isProcessing,
|
|
||||||
isConnected: system.isConnected,
|
|
||||||
};
|
|
||||||
},
|
|
||||||
{
|
|
||||||
memoizeOptions: {
|
|
||||||
resultEqualityCheck: isEqual,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
);
|
|
||||||
|
|
||||||
const ProcessButtons = () => {
|
|
||||||
const { isProcessing, isConnected } = useAppSelector(systemSelector);
|
|
||||||
|
|
||||||
const dispatch = useAppDispatch();
|
|
||||||
|
|
||||||
const isReady = useCheckParameters();
|
|
||||||
|
|
||||||
return (
|
|
||||||
<Flex gap={2} direction={'column'} alignItems={'space-between'} height={'100%'}>
|
|
||||||
<SDButton
|
|
||||||
label='Generate'
|
|
||||||
type='submit'
|
|
||||||
colorScheme='green'
|
|
||||||
flexGrow={1}
|
|
||||||
isDisabled={!isReady}
|
|
||||||
fontSize={'md'}
|
|
||||||
size={'md'}
|
|
||||||
onClick={() => dispatch(generateImage())}
|
|
||||||
/>
|
|
||||||
<SDButton
|
|
||||||
label='Cancel'
|
|
||||||
colorScheme='red'
|
|
||||||
flexGrow={1}
|
|
||||||
fontSize={'md'}
|
|
||||||
size={'md'}
|
|
||||||
isDisabled={!isConnected || !isProcessing}
|
|
||||||
onClick={() => dispatch(cancelProcessing())}
|
|
||||||
/>
|
|
||||||
</Flex>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
export default ProcessButtons;
|
|
@ -1,25 +0,0 @@
|
|||||||
import { Textarea } from '@chakra-ui/react';
|
|
||||||
import { useAppDispatch, useAppSelector } from '../../app/hooks';
|
|
||||||
import { RootState } from '../../app/store';
|
|
||||||
import { setPrompt } from '../sd/sdSlice';
|
|
||||||
|
|
||||||
const PromptInput = () => {
|
|
||||||
const { prompt } = useAppSelector((state: RootState) => state.sd);
|
|
||||||
const dispatch = useAppDispatch();
|
|
||||||
|
|
||||||
return (
|
|
||||||
<Textarea
|
|
||||||
id='prompt'
|
|
||||||
name='prompt'
|
|
||||||
resize='none'
|
|
||||||
size={'lg'}
|
|
||||||
height={'100%'}
|
|
||||||
isInvalid={!prompt.length}
|
|
||||||
onChange={(e) => dispatch(setPrompt(e.target.value))}
|
|
||||||
value={prompt}
|
|
||||||
placeholder="I'm dreaming of..."
|
|
||||||
/>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
export default PromptInput;
|
|
@ -1,51 +0,0 @@
|
|||||||
import {
|
|
||||||
Slider,
|
|
||||||
SliderTrack,
|
|
||||||
SliderFilledTrack,
|
|
||||||
SliderThumb,
|
|
||||||
FormControl,
|
|
||||||
FormLabel,
|
|
||||||
Text,
|
|
||||||
Flex,
|
|
||||||
SliderProps,
|
|
||||||
} from '@chakra-ui/react';
|
|
||||||
|
|
||||||
interface Props extends SliderProps {
|
|
||||||
label: string;
|
|
||||||
value: number;
|
|
||||||
fontSize?: number | string;
|
|
||||||
}
|
|
||||||
|
|
||||||
const SDSlider = ({
|
|
||||||
label,
|
|
||||||
value,
|
|
||||||
fontSize = 'sm',
|
|
||||||
onChange,
|
|
||||||
...rest
|
|
||||||
}: Props) => {
|
|
||||||
return (
|
|
||||||
<FormControl>
|
|
||||||
<Flex gap={2}>
|
|
||||||
<FormLabel marginInlineEnd={0} marginBottom={1}>
|
|
||||||
<Text fontSize={fontSize} whiteSpace='nowrap'>
|
|
||||||
{label}
|
|
||||||
</Text>
|
|
||||||
</FormLabel>
|
|
||||||
<Slider
|
|
||||||
aria-label={label}
|
|
||||||
focusThumbOnChange={true}
|
|
||||||
value={value}
|
|
||||||
onChange={onChange}
|
|
||||||
{...rest}
|
|
||||||
>
|
|
||||||
<SliderTrack>
|
|
||||||
<SliderFilledTrack />
|
|
||||||
</SliderTrack>
|
|
||||||
<SliderThumb />
|
|
||||||
</Slider>
|
|
||||||
</Flex>
|
|
||||||
</FormControl>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
export default SDSlider;
|
|
@ -1,62 +0,0 @@
|
|||||||
import { Flex } from '@chakra-ui/react';
|
|
||||||
|
|
||||||
import { RootState } from '../../app/store';
|
|
||||||
import { useAppDispatch, useAppSelector } from '../../app/hooks';
|
|
||||||
|
|
||||||
import { setCfgScale, setSampler, setSteps, SDState } from '../sd/sdSlice';
|
|
||||||
|
|
||||||
import SDNumberInput from '../../components/SDNumberInput';
|
|
||||||
import SDSelect from '../../components/SDSelect';
|
|
||||||
|
|
||||||
import { SAMPLERS } from '../../app/constants';
|
|
||||||
import { createSelector } from '@reduxjs/toolkit';
|
|
||||||
import { isEqual } from 'lodash';
|
|
||||||
|
|
||||||
const sdSelector = createSelector(
|
|
||||||
(state: RootState) => state.sd,
|
|
||||||
(sd: SDState) => {
|
|
||||||
return {
|
|
||||||
steps: sd.steps,
|
|
||||||
cfgScale: sd.cfgScale,
|
|
||||||
sampler: sd.sampler,
|
|
||||||
};
|
|
||||||
},
|
|
||||||
{
|
|
||||||
memoizeOptions: {
|
|
||||||
resultEqualityCheck: isEqual,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
);
|
|
||||||
|
|
||||||
const SamplerOptions = () => {
|
|
||||||
const { steps, cfgScale, sampler } = useAppSelector(sdSelector);
|
|
||||||
|
|
||||||
const dispatch = useAppDispatch();
|
|
||||||
|
|
||||||
return (
|
|
||||||
<Flex gap={2} direction={'column'}>
|
|
||||||
<SDNumberInput
|
|
||||||
label='Steps'
|
|
||||||
min={1}
|
|
||||||
step={1}
|
|
||||||
precision={0}
|
|
||||||
onChange={(v) => dispatch(setSteps(Number(v)))}
|
|
||||||
value={steps}
|
|
||||||
/>
|
|
||||||
<SDNumberInput
|
|
||||||
label='CFG scale'
|
|
||||||
step={0.5}
|
|
||||||
onChange={(v) => dispatch(setCfgScale(Number(v)))}
|
|
||||||
value={cfgScale}
|
|
||||||
/>
|
|
||||||
<SDSelect
|
|
||||||
label='Sampler'
|
|
||||||
value={sampler}
|
|
||||||
onChange={(e) => dispatch(setSampler(e.target.value))}
|
|
||||||
validValues={SAMPLERS}
|
|
||||||
/>
|
|
||||||
</Flex>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
export default SamplerOptions;
|
|
@ -1,144 +0,0 @@
|
|||||||
import {
|
|
||||||
Flex,
|
|
||||||
Input,
|
|
||||||
HStack,
|
|
||||||
FormControl,
|
|
||||||
FormLabel,
|
|
||||||
Text,
|
|
||||||
Button,
|
|
||||||
} from '@chakra-ui/react';
|
|
||||||
import { createSelector } from '@reduxjs/toolkit';
|
|
||||||
import { isEqual } from 'lodash';
|
|
||||||
import { NUMPY_RAND_MAX, NUMPY_RAND_MIN } from '../../app/constants';
|
|
||||||
import { useAppDispatch, useAppSelector } from '../../app/hooks';
|
|
||||||
import { RootState } from '../../app/store';
|
|
||||||
import SDNumberInput from '../../components/SDNumberInput';
|
|
||||||
import SDSwitch from '../../components/SDSwitch';
|
|
||||||
import {
|
|
||||||
randomizeSeed,
|
|
||||||
SDState,
|
|
||||||
setIterations,
|
|
||||||
setSeed,
|
|
||||||
setSeedWeights,
|
|
||||||
setShouldGenerateVariations,
|
|
||||||
setShouldRandomizeSeed,
|
|
||||||
setVariantAmount,
|
|
||||||
} from './sdSlice';
|
|
||||||
import { validateSeedWeights } from './util/seedWeightPairs';
|
|
||||||
|
|
||||||
const sdSelector = createSelector(
|
|
||||||
(state: RootState) => state.sd,
|
|
||||||
(sd: SDState) => {
|
|
||||||
return {
|
|
||||||
variantAmount: sd.variantAmount,
|
|
||||||
seedWeights: sd.seedWeights,
|
|
||||||
shouldGenerateVariations: sd.shouldGenerateVariations,
|
|
||||||
shouldRandomizeSeed: sd.shouldRandomizeSeed,
|
|
||||||
seed: sd.seed,
|
|
||||||
iterations: sd.iterations,
|
|
||||||
};
|
|
||||||
},
|
|
||||||
{
|
|
||||||
memoizeOptions: {
|
|
||||||
resultEqualityCheck: isEqual,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
);
|
|
||||||
|
|
||||||
const SeedVariationOptions = () => {
|
|
||||||
const {
|
|
||||||
shouldGenerateVariations,
|
|
||||||
variantAmount,
|
|
||||||
seedWeights,
|
|
||||||
shouldRandomizeSeed,
|
|
||||||
seed,
|
|
||||||
iterations,
|
|
||||||
} = useAppSelector(sdSelector);
|
|
||||||
|
|
||||||
const dispatch = useAppDispatch();
|
|
||||||
|
|
||||||
return (
|
|
||||||
<Flex gap={2} direction={'column'}>
|
|
||||||
<SDNumberInput
|
|
||||||
label='Images to generate'
|
|
||||||
step={1}
|
|
||||||
min={1}
|
|
||||||
precision={0}
|
|
||||||
onChange={(v) => dispatch(setIterations(Number(v)))}
|
|
||||||
value={iterations}
|
|
||||||
/>
|
|
||||||
<SDSwitch
|
|
||||||
label='Randomize seed on generation'
|
|
||||||
isChecked={shouldRandomizeSeed}
|
|
||||||
onChange={(e) =>
|
|
||||||
dispatch(setShouldRandomizeSeed(e.target.checked))
|
|
||||||
}
|
|
||||||
/>
|
|
||||||
<Flex gap={2}>
|
|
||||||
<SDNumberInput
|
|
||||||
label='Seed'
|
|
||||||
step={1}
|
|
||||||
precision={0}
|
|
||||||
flexGrow={1}
|
|
||||||
min={NUMPY_RAND_MIN}
|
|
||||||
max={NUMPY_RAND_MAX}
|
|
||||||
isDisabled={shouldRandomizeSeed}
|
|
||||||
isInvalid={seed < 0 && shouldGenerateVariations}
|
|
||||||
onChange={(v) => dispatch(setSeed(Number(v)))}
|
|
||||||
value={seed}
|
|
||||||
/>
|
|
||||||
<Button
|
|
||||||
size={'sm'}
|
|
||||||
isDisabled={shouldRandomizeSeed}
|
|
||||||
onClick={() => dispatch(randomizeSeed())}
|
|
||||||
>
|
|
||||||
<Text pl={2} pr={2}>
|
|
||||||
Shuffle
|
|
||||||
</Text>
|
|
||||||
</Button>
|
|
||||||
</Flex>
|
|
||||||
<SDSwitch
|
|
||||||
label='Generate variations'
|
|
||||||
isChecked={shouldGenerateVariations}
|
|
||||||
width={'auto'}
|
|
||||||
onChange={(e) =>
|
|
||||||
dispatch(setShouldGenerateVariations(e.target.checked))
|
|
||||||
}
|
|
||||||
/>
|
|
||||||
<SDNumberInput
|
|
||||||
label='Variation amount'
|
|
||||||
value={variantAmount}
|
|
||||||
step={0.01}
|
|
||||||
min={0}
|
|
||||||
max={1}
|
|
||||||
isDisabled={!shouldGenerateVariations}
|
|
||||||
onChange={(v) => dispatch(setVariantAmount(Number(v)))}
|
|
||||||
/>
|
|
||||||
<FormControl
|
|
||||||
isInvalid={
|
|
||||||
shouldGenerateVariations &&
|
|
||||||
!(validateSeedWeights(seedWeights) || seedWeights === '')
|
|
||||||
}
|
|
||||||
flexGrow={1}
|
|
||||||
isDisabled={!shouldGenerateVariations}
|
|
||||||
>
|
|
||||||
<HStack>
|
|
||||||
<FormLabel marginInlineEnd={0} marginBottom={1}>
|
|
||||||
<Text whiteSpace='nowrap'>
|
|
||||||
Seed Weights
|
|
||||||
</Text>
|
|
||||||
</FormLabel>
|
|
||||||
<Input
|
|
||||||
size={'sm'}
|
|
||||||
value={seedWeights}
|
|
||||||
onChange={(e) =>
|
|
||||||
dispatch(setSeedWeights(e.target.value))
|
|
||||||
}
|
|
||||||
/>
|
|
||||||
</HStack>
|
|
||||||
</FormControl>
|
|
||||||
</Flex>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
export default SeedVariationOptions;
|
|
@ -1,92 +0,0 @@
|
|||||||
import {
|
|
||||||
Flex,
|
|
||||||
FormControl,
|
|
||||||
FormLabel,
|
|
||||||
HStack,
|
|
||||||
Input,
|
|
||||||
Text,
|
|
||||||
} from '@chakra-ui/react';
|
|
||||||
import { createSelector } from '@reduxjs/toolkit';
|
|
||||||
import { isEqual } from 'lodash';
|
|
||||||
import { useAppDispatch, useAppSelector } from '../../app/hooks';
|
|
||||||
import { RootState } from '../../app/store';
|
|
||||||
import SDNumberInput from '../../components/SDNumberInput';
|
|
||||||
import SDSwitch from '../../components/SDSwitch';
|
|
||||||
import {
|
|
||||||
SDState,
|
|
||||||
setSeedWeights,
|
|
||||||
setShouldGenerateVariations,
|
|
||||||
setVariantAmount,
|
|
||||||
} from './sdSlice';
|
|
||||||
import { validateSeedWeights } from './util/seedWeightPairs';
|
|
||||||
|
|
||||||
const sdSelector = createSelector(
|
|
||||||
(state: RootState) => state.sd,
|
|
||||||
(sd: SDState) => {
|
|
||||||
return {
|
|
||||||
variantAmount: sd.variantAmount,
|
|
||||||
seedWeights: sd.seedWeights,
|
|
||||||
shouldGenerateVariations: sd.shouldGenerateVariations,
|
|
||||||
};
|
|
||||||
},
|
|
||||||
{
|
|
||||||
memoizeOptions: {
|
|
||||||
resultEqualityCheck: isEqual,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
);
|
|
||||||
|
|
||||||
const Variant = () => {
|
|
||||||
const { shouldGenerateVariations, variantAmount, seedWeights } =
|
|
||||||
useAppSelector(sdSelector);
|
|
||||||
|
|
||||||
const dispatch = useAppDispatch();
|
|
||||||
|
|
||||||
return (
|
|
||||||
<Flex gap={2} alignItems={'center'} pl={1}>
|
|
||||||
<SDSwitch
|
|
||||||
label='Generate variations'
|
|
||||||
isChecked={shouldGenerateVariations}
|
|
||||||
width={'auto'}
|
|
||||||
onChange={(e) =>
|
|
||||||
dispatch(setShouldGenerateVariations(e.target.checked))
|
|
||||||
}
|
|
||||||
/>
|
|
||||||
<SDNumberInput
|
|
||||||
label='Amount'
|
|
||||||
value={variantAmount}
|
|
||||||
step={0.01}
|
|
||||||
min={0}
|
|
||||||
max={1}
|
|
||||||
width={240}
|
|
||||||
isDisabled={!shouldGenerateVariations}
|
|
||||||
onChange={(v) => dispatch(setVariantAmount(Number(v)))}
|
|
||||||
/>
|
|
||||||
<FormControl
|
|
||||||
isInvalid={
|
|
||||||
shouldGenerateVariations &&
|
|
||||||
!(validateSeedWeights(seedWeights) || seedWeights === '')
|
|
||||||
}
|
|
||||||
flexGrow={1}
|
|
||||||
isDisabled={!shouldGenerateVariations}
|
|
||||||
>
|
|
||||||
<HStack>
|
|
||||||
<FormLabel marginInlineEnd={0} marginBottom={1}>
|
|
||||||
<Text fontSize={'sm'} whiteSpace='nowrap'>
|
|
||||||
Seed Weights
|
|
||||||
</Text>
|
|
||||||
</FormLabel>
|
|
||||||
<Input
|
|
||||||
size={'sm'}
|
|
||||||
value={seedWeights}
|
|
||||||
onChange={(e) =>
|
|
||||||
dispatch(setSeedWeights(e.target.value))
|
|
||||||
}
|
|
||||||
/>
|
|
||||||
</HStack>
|
|
||||||
</FormControl>
|
|
||||||
</Flex>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
export default Variant;
|
|
@ -1,56 +0,0 @@
|
|||||||
export interface SeedWeightPair {
|
|
||||||
seed: number;
|
|
||||||
weight: number;
|
|
||||||
}
|
|
||||||
|
|
||||||
export type SeedWeights = Array<Array<number>>;
|
|
||||||
|
|
||||||
export const stringToSeedWeights = (string: string): SeedWeights | boolean => {
|
|
||||||
const stringPairs = string.split(',');
|
|
||||||
const arrPairs = stringPairs.map((p) => p.split(':'));
|
|
||||||
const pairs = arrPairs.map((p) => {
|
|
||||||
return [parseInt(p[0]), parseFloat(p[1])];
|
|
||||||
});
|
|
||||||
|
|
||||||
if (!validateSeedWeights(pairs)) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
return pairs;
|
|
||||||
};
|
|
||||||
|
|
||||||
export const validateSeedWeights = (
|
|
||||||
seedWeights: SeedWeights | string
|
|
||||||
): boolean => {
|
|
||||||
return typeof seedWeights === 'string'
|
|
||||||
? Boolean(stringToSeedWeights(seedWeights))
|
|
||||||
: Boolean(
|
|
||||||
seedWeights.length &&
|
|
||||||
!seedWeights.some((pair) => {
|
|
||||||
const [seed, weight] = pair;
|
|
||||||
const isSeedValid = !isNaN(parseInt(seed.toString(), 10));
|
|
||||||
const isWeightValid =
|
|
||||||
!isNaN(parseInt(weight.toString(), 10)) &&
|
|
||||||
weight >= 0 &&
|
|
||||||
weight <= 1;
|
|
||||||
return !(isSeedValid && isWeightValid);
|
|
||||||
})
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
export const seedWeightsToString = (
|
|
||||||
seedWeights: SeedWeights
|
|
||||||
): string | boolean => {
|
|
||||||
if (!validateSeedWeights(seedWeights)) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
return seedWeights.reduce((acc, pair, i, arr) => {
|
|
||||||
const [seed, weight] = pair;
|
|
||||||
acc += `${seed}:${weight}`;
|
|
||||||
if (i !== arr.length - 1) {
|
|
||||||
acc += ',';
|
|
||||||
}
|
|
||||||
return acc;
|
|
||||||
}, '');
|
|
||||||
};
|
|
@ -1,11 +1,11 @@
|
|||||||
import {
|
import {
|
||||||
IconButton,
|
IconButton,
|
||||||
useColorModeValue,
|
useColorModeValue,
|
||||||
Flex,
|
Flex,
|
||||||
Text,
|
Text,
|
||||||
Tooltip,
|
Tooltip,
|
||||||
} from '@chakra-ui/react';
|
} from '@chakra-ui/react';
|
||||||
import { useAppDispatch, useAppSelector } from '../../app/hooks';
|
import { useAppDispatch, useAppSelector } from '../../app/store';
|
||||||
import { RootState } from '../../app/store';
|
import { RootState } from '../../app/store';
|
||||||
import { setShouldShowLogViewer, SystemState } from './systemSlice';
|
import { setShouldShowLogViewer, SystemState } from './systemSlice';
|
||||||
import { useLayoutEffect, useRef, useState } from 'react';
|
import { useLayoutEffect, useRef, useState } from 'react';
|
||||||
@ -14,112 +14,138 @@ import { createSelector } from '@reduxjs/toolkit';
|
|||||||
import { isEqual } from 'lodash';
|
import { isEqual } from 'lodash';
|
||||||
|
|
||||||
const logSelector = createSelector(
|
const logSelector = createSelector(
|
||||||
(state: RootState) => state.system,
|
(state: RootState) => state.system,
|
||||||
(system: SystemState) => system.log,
|
(system: SystemState) => system.log,
|
||||||
{
|
{
|
||||||
memoizeOptions: {
|
memoizeOptions: {
|
||||||
resultEqualityCheck: (a, b) => a.length === b.length,
|
// We don't need a deep equality check for this selector.
|
||||||
},
|
resultEqualityCheck: (a, b) => a.length === b.length,
|
||||||
}
|
},
|
||||||
|
}
|
||||||
);
|
);
|
||||||
|
|
||||||
const systemSelector = createSelector(
|
const systemSelector = createSelector(
|
||||||
(state: RootState) => state.system,
|
(state: RootState) => state.system,
|
||||||
(system: SystemState) => {
|
(system: SystemState) => {
|
||||||
return { shouldShowLogViewer: system.shouldShowLogViewer };
|
return { shouldShowLogViewer: system.shouldShowLogViewer };
|
||||||
|
},
|
||||||
|
{
|
||||||
|
memoizeOptions: {
|
||||||
|
resultEqualityCheck: isEqual,
|
||||||
},
|
},
|
||||||
{
|
}
|
||||||
memoizeOptions: {
|
|
||||||
resultEqualityCheck: isEqual,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
);
|
);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Basic log viewer, floats on bottom of page.
|
||||||
|
*/
|
||||||
const LogViewer = () => {
|
const LogViewer = () => {
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const bg = useColorModeValue('gray.50', 'gray.900');
|
const log = useAppSelector(logSelector);
|
||||||
const borderColor = useColorModeValue('gray.500', 'gray.500');
|
const { shouldShowLogViewer } = useAppSelector(systemSelector);
|
||||||
const [shouldAutoscroll, setShouldAutoscroll] = useState<boolean>(true);
|
|
||||||
|
|
||||||
const log = useAppSelector(logSelector);
|
// Set colors based on dark/light mode
|
||||||
const { shouldShowLogViewer } = useAppSelector(systemSelector);
|
const bg = useColorModeValue('gray.50', 'gray.900');
|
||||||
|
const borderColor = useColorModeValue('gray.500', 'gray.500');
|
||||||
|
const logTextColors = useColorModeValue(
|
||||||
|
{
|
||||||
|
info: undefined,
|
||||||
|
warning: 'yellow.500',
|
||||||
|
error: 'red.500',
|
||||||
|
},
|
||||||
|
{
|
||||||
|
info: undefined,
|
||||||
|
warning: 'yellow.300',
|
||||||
|
error: 'red.300',
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
const viewerRef = useRef<HTMLDivElement>(null);
|
// Rudimentary autoscroll
|
||||||
|
const [shouldAutoscroll, setShouldAutoscroll] = useState<boolean>(true);
|
||||||
|
const viewerRef = useRef<HTMLDivElement>(null);
|
||||||
|
|
||||||
useLayoutEffect(() => {
|
/**
|
||||||
if (viewerRef.current !== null && shouldAutoscroll) {
|
* If autoscroll is on, scroll to the bottom when:
|
||||||
viewerRef.current.scrollTop = viewerRef.current.scrollHeight;
|
* - log updates
|
||||||
}
|
* - viewer is toggled
|
||||||
});
|
*
|
||||||
|
* Also scroll to the bottom whenever autoscroll is turned on.
|
||||||
|
*/
|
||||||
|
useLayoutEffect(() => {
|
||||||
|
if (viewerRef.current !== null && shouldAutoscroll) {
|
||||||
|
viewerRef.current.scrollTop = viewerRef.current.scrollHeight;
|
||||||
|
}
|
||||||
|
}, [shouldAutoscroll, log, shouldShowLogViewer]);
|
||||||
|
|
||||||
return (
|
const handleClickLogViewerToggle = () => {
|
||||||
<>
|
dispatch(setShouldShowLogViewer(!shouldShowLogViewer));
|
||||||
{shouldShowLogViewer && (
|
};
|
||||||
<Flex
|
|
||||||
position={'fixed'}
|
return (
|
||||||
left={0}
|
<>
|
||||||
bottom={0}
|
{shouldShowLogViewer && (
|
||||||
height='200px'
|
<Flex
|
||||||
width='100vw'
|
position={'fixed'}
|
||||||
overflow='auto'
|
left={0}
|
||||||
direction='column'
|
bottom={0}
|
||||||
fontFamily='monospace'
|
height="200px" // TODO: Make the log viewer resizeable.
|
||||||
fontSize='sm'
|
width="100vw"
|
||||||
pl={12}
|
overflow="auto"
|
||||||
pr={2}
|
direction="column"
|
||||||
pb={2}
|
fontFamily="monospace"
|
||||||
borderTopWidth='4px'
|
fontSize="sm"
|
||||||
borderColor={borderColor}
|
pl={12}
|
||||||
background={bg}
|
pr={2}
|
||||||
ref={viewerRef}
|
pb={2}
|
||||||
>
|
borderTopWidth="4px"
|
||||||
{log.map((entry, i) => (
|
borderColor={borderColor}
|
||||||
<Flex gap={2} key={i}>
|
background={bg}
|
||||||
<Text fontSize='sm' fontWeight={'semibold'}>
|
ref={viewerRef}
|
||||||
{entry.timestamp}:
|
>
|
||||||
</Text>
|
{log.map((entry, i) => {
|
||||||
<Text fontSize='sm' wordBreak={'break-all'}>
|
const { timestamp, message, level } = entry;
|
||||||
{entry.message}
|
return (
|
||||||
</Text>
|
<Flex gap={2} key={i} textColor={logTextColors[level]}>
|
||||||
</Flex>
|
<Text fontSize="sm" fontWeight={'semibold'}>
|
||||||
))}
|
{timestamp}:
|
||||||
</Flex>
|
</Text>
|
||||||
)}
|
<Text fontSize="sm" wordBreak={'break-all'}>
|
||||||
{shouldShowLogViewer && (
|
{message}
|
||||||
<Tooltip
|
</Text>
|
||||||
label={
|
</Flex>
|
||||||
shouldAutoscroll ? 'Autoscroll on' : 'Autoscroll off'
|
);
|
||||||
}
|
})}
|
||||||
>
|
</Flex>
|
||||||
<IconButton
|
)}
|
||||||
size='sm'
|
{shouldShowLogViewer && (
|
||||||
position={'fixed'}
|
<Tooltip label={shouldAutoscroll ? 'Autoscroll on' : 'Autoscroll off'}>
|
||||||
left={2}
|
<IconButton
|
||||||
bottom={12}
|
size="sm"
|
||||||
aria-label='Toggle autoscroll'
|
position={'fixed'}
|
||||||
variant={'solid'}
|
left={2}
|
||||||
colorScheme={shouldAutoscroll ? 'blue' : 'gray'}
|
bottom={12}
|
||||||
icon={<FaAngleDoubleDown />}
|
aria-label="Toggle autoscroll"
|
||||||
onClick={() => setShouldAutoscroll(!shouldAutoscroll)}
|
variant={'solid'}
|
||||||
/>
|
colorScheme={shouldAutoscroll ? 'blue' : 'gray'}
|
||||||
</Tooltip>
|
icon={<FaAngleDoubleDown />}
|
||||||
)}
|
onClick={() => setShouldAutoscroll(!shouldAutoscroll)}
|
||||||
<Tooltip label={shouldShowLogViewer ? 'Hide logs' : 'Show logs'}>
|
/>
|
||||||
<IconButton
|
</Tooltip>
|
||||||
size='sm'
|
)}
|
||||||
position={'fixed'}
|
<Tooltip label={shouldShowLogViewer ? 'Hide logs' : 'Show logs'}>
|
||||||
left={2}
|
<IconButton
|
||||||
bottom={2}
|
size="sm"
|
||||||
variant={'solid'}
|
position={'fixed'}
|
||||||
aria-label='Toggle Log Viewer'
|
left={2}
|
||||||
icon={shouldShowLogViewer ? <FaMinus /> : <FaCode />}
|
bottom={2}
|
||||||
onClick={() =>
|
variant={'solid'}
|
||||||
dispatch(setShouldShowLogViewer(!shouldShowLogViewer))
|
aria-label="Toggle Log Viewer"
|
||||||
}
|
icon={shouldShowLogViewer ? <FaMinus /> : <FaCode />}
|
||||||
/>
|
onClick={handleClickLogViewerToggle}
|
||||||
</Tooltip>
|
/>
|
||||||
</>
|
</Tooltip>
|
||||||
);
|
</>
|
||||||
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
export default LogViewer;
|
export default LogViewer;
|
||||||
|
38
frontend/src/features/system/ProgressBar.tsx
Normal file
38
frontend/src/features/system/ProgressBar.tsx
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
import { Progress } from '@chakra-ui/react';
|
||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { isEqual } from 'lodash';
|
||||||
|
import { useAppSelector } from '../../app/store';
|
||||||
|
import { RootState } from '../../app/store';
|
||||||
|
import { SystemState } from '../system/systemSlice';
|
||||||
|
|
||||||
|
const systemSelector = createSelector(
|
||||||
|
(state: RootState) => state.system,
|
||||||
|
(system: SystemState) => {
|
||||||
|
return {
|
||||||
|
isProcessing: system.isProcessing,
|
||||||
|
currentStep: system.currentStep,
|
||||||
|
totalSteps: system.totalSteps,
|
||||||
|
currentStatusHasSteps: system.currentStatusHasSteps,
|
||||||
|
};
|
||||||
|
},
|
||||||
|
{
|
||||||
|
memoizeOptions: { resultEqualityCheck: isEqual },
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
const ProgressBar = () => {
|
||||||
|
const { isProcessing, currentStep, totalSteps, currentStatusHasSteps } =
|
||||||
|
useAppSelector(systemSelector);
|
||||||
|
|
||||||
|
const value = currentStep ? Math.round((currentStep * 100) / totalSteps) : 0;
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Progress
|
||||||
|
height="10px"
|
||||||
|
value={value}
|
||||||
|
isIndeterminate={isProcessing && !currentStatusHasSteps}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default ProgressBar;
|
@ -1,170 +1,164 @@
|
|||||||
import {
|
import {
|
||||||
Flex,
|
Button,
|
||||||
FormControl,
|
Flex,
|
||||||
FormLabel,
|
FormControl,
|
||||||
Heading,
|
FormLabel,
|
||||||
HStack,
|
Heading,
|
||||||
Modal,
|
HStack,
|
||||||
ModalBody,
|
Modal,
|
||||||
ModalCloseButton,
|
ModalBody,
|
||||||
ModalContent,
|
ModalCloseButton,
|
||||||
ModalFooter,
|
ModalContent,
|
||||||
ModalHeader,
|
ModalFooter,
|
||||||
ModalOverlay,
|
ModalHeader,
|
||||||
Switch,
|
ModalOverlay,
|
||||||
Text,
|
Switch,
|
||||||
useDisclosure,
|
Text,
|
||||||
|
useDisclosure,
|
||||||
} from '@chakra-ui/react';
|
} from '@chakra-ui/react';
|
||||||
import { useAppDispatch, useAppSelector } from '../../app/hooks';
|
import { useAppDispatch, useAppSelector } from '../../app/store';
|
||||||
import {
|
import {
|
||||||
setShouldConfirmOnDelete,
|
setShouldConfirmOnDelete,
|
||||||
setShouldDisplayInProgress,
|
setShouldDisplayInProgress,
|
||||||
SystemState,
|
SystemState,
|
||||||
} from './systemSlice';
|
} from './systemSlice';
|
||||||
import { RootState } from '../../app/store';
|
import { RootState } from '../../app/store';
|
||||||
import SDButton from '../../components/SDButton';
|
|
||||||
import { persistor } from '../../main';
|
import { persistor } from '../../main';
|
||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { isEqual } from 'lodash';
|
import { isEqual } from 'lodash';
|
||||||
import { cloneElement, ReactElement } from 'react';
|
import { cloneElement, ReactElement } from 'react';
|
||||||
|
|
||||||
const systemSelector = createSelector(
|
const systemSelector = createSelector(
|
||||||
(state: RootState) => state.system,
|
(state: RootState) => state.system,
|
||||||
(system: SystemState) => {
|
(system: SystemState) => {
|
||||||
const { shouldDisplayInProgress, shouldConfirmOnDelete } = system;
|
const { shouldDisplayInProgress, shouldConfirmOnDelete } = system;
|
||||||
return { shouldDisplayInProgress, shouldConfirmOnDelete };
|
return { shouldDisplayInProgress, shouldConfirmOnDelete };
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
memoizeOptions: { resultEqualityCheck: isEqual },
|
memoizeOptions: { resultEqualityCheck: isEqual },
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
|
|
||||||
type Props = {
|
type SettingsModalProps = {
|
||||||
children: ReactElement;
|
/* The button to open the Settings Modal */
|
||||||
|
children: ReactElement;
|
||||||
};
|
};
|
||||||
|
|
||||||
const SettingsModal = ({ children }: Props) => {
|
/**
|
||||||
const {
|
* Modal for app settings. Also provides Reset functionality in which the
|
||||||
isOpen: isSettingsModalOpen,
|
* app's localstorage is wiped via redux-persist.
|
||||||
onOpen: onSettingsModalOpen,
|
*
|
||||||
onClose: onSettingsModalClose,
|
* Secondary post-reset modal is included here.
|
||||||
} = useDisclosure();
|
*/
|
||||||
|
const SettingsModal = ({ children }: SettingsModalProps) => {
|
||||||
|
const {
|
||||||
|
isOpen: isSettingsModalOpen,
|
||||||
|
onOpen: onSettingsModalOpen,
|
||||||
|
onClose: onSettingsModalClose,
|
||||||
|
} = useDisclosure();
|
||||||
|
|
||||||
const {
|
const {
|
||||||
isOpen: isRefreshModalOpen,
|
isOpen: isRefreshModalOpen,
|
||||||
onOpen: onRefreshModalOpen,
|
onOpen: onRefreshModalOpen,
|
||||||
onClose: onRefreshModalClose,
|
onClose: onRefreshModalClose,
|
||||||
} = useDisclosure();
|
} = useDisclosure();
|
||||||
|
|
||||||
const { shouldDisplayInProgress, shouldConfirmOnDelete } =
|
const { shouldDisplayInProgress, shouldConfirmOnDelete } =
|
||||||
useAppSelector(systemSelector);
|
useAppSelector(systemSelector);
|
||||||
|
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
const handleClickResetWebUI = () => {
|
/**
|
||||||
persistor.purge().then(() => {
|
* Resets localstorage, then opens a secondary modal informing user to
|
||||||
onSettingsModalClose();
|
* refresh their browser.
|
||||||
onRefreshModalOpen();
|
* */
|
||||||
});
|
const handleClickResetWebUI = () => {
|
||||||
};
|
persistor.purge().then(() => {
|
||||||
|
onSettingsModalClose();
|
||||||
|
onRefreshModalOpen();
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<>
|
<>
|
||||||
{cloneElement(children, {
|
{cloneElement(children, {
|
||||||
onClick: onSettingsModalOpen,
|
onClick: onSettingsModalOpen,
|
||||||
})}
|
})}
|
||||||
|
|
||||||
<Modal isOpen={isSettingsModalOpen} onClose={onSettingsModalClose}>
|
<Modal isOpen={isSettingsModalOpen} onClose={onSettingsModalClose}>
|
||||||
<ModalOverlay />
|
<ModalOverlay />
|
||||||
<ModalContent>
|
<ModalContent>
|
||||||
<ModalHeader>Settings</ModalHeader>
|
<ModalHeader>Settings</ModalHeader>
|
||||||
<ModalCloseButton />
|
<ModalCloseButton />
|
||||||
<ModalBody>
|
<ModalBody>
|
||||||
<Flex gap={5} direction='column'>
|
<Flex gap={5} direction="column">
|
||||||
<FormControl>
|
<FormControl>
|
||||||
<HStack>
|
<HStack>
|
||||||
<FormLabel marginBottom={1}>
|
<FormLabel marginBottom={1}>
|
||||||
Display in-progress images (slower)
|
Display in-progress images (slower)
|
||||||
</FormLabel>
|
</FormLabel>
|
||||||
<Switch
|
<Switch
|
||||||
isChecked={shouldDisplayInProgress}
|
isChecked={shouldDisplayInProgress}
|
||||||
onChange={(e) =>
|
onChange={(e) =>
|
||||||
dispatch(
|
dispatch(setShouldDisplayInProgress(e.target.checked))
|
||||||
setShouldDisplayInProgress(
|
}
|
||||||
e.target.checked
|
/>
|
||||||
)
|
</HStack>
|
||||||
)
|
</FormControl>
|
||||||
}
|
<FormControl>
|
||||||
/>
|
<HStack>
|
||||||
</HStack>
|
<FormLabel marginBottom={1}>Confirm on delete</FormLabel>
|
||||||
</FormControl>
|
<Switch
|
||||||
<FormControl>
|
isChecked={shouldConfirmOnDelete}
|
||||||
<HStack>
|
onChange={(e) =>
|
||||||
<FormLabel marginBottom={1}>
|
dispatch(setShouldConfirmOnDelete(e.target.checked))
|
||||||
Confirm on delete
|
}
|
||||||
</FormLabel>
|
/>
|
||||||
<Switch
|
</HStack>
|
||||||
isChecked={shouldConfirmOnDelete}
|
</FormControl>
|
||||||
onChange={(e) =>
|
|
||||||
dispatch(
|
|
||||||
setShouldConfirmOnDelete(
|
|
||||||
e.target.checked
|
|
||||||
)
|
|
||||||
)
|
|
||||||
}
|
|
||||||
/>
|
|
||||||
</HStack>
|
|
||||||
</FormControl>
|
|
||||||
|
|
||||||
<Heading size={'md'}>Reset Web UI</Heading>
|
<Heading size={'md'}>Reset Web UI</Heading>
|
||||||
<Text>
|
<Text>
|
||||||
Resetting the web UI only resets the browser's
|
Resetting the web UI only resets the browser's local cache of
|
||||||
local cache of your images and remembered
|
your images and remembered settings. It does not delete any
|
||||||
settings. It does not delete any images from
|
images from disk.
|
||||||
disk.
|
</Text>
|
||||||
</Text>
|
<Text>
|
||||||
<Text>
|
If images aren't showing up in the gallery or something else
|
||||||
If images aren't showing up in the gallery or
|
isn't working, please try resetting before submitting an issue
|
||||||
something else isn't working, please try
|
on GitHub.
|
||||||
resetting before submitting an issue on GitHub.
|
</Text>
|
||||||
</Text>
|
<Button colorScheme="red" onClick={handleClickResetWebUI}>
|
||||||
<SDButton
|
Reset Web UI
|
||||||
label='Reset Web UI'
|
</Button>
|
||||||
colorScheme='red'
|
</Flex>
|
||||||
onClick={handleClickResetWebUI}
|
</ModalBody>
|
||||||
/>
|
|
||||||
</Flex>
|
|
||||||
</ModalBody>
|
|
||||||
|
|
||||||
<ModalFooter>
|
<ModalFooter>
|
||||||
<SDButton
|
<Button onClick={onSettingsModalClose}>Close</Button>
|
||||||
label='Close'
|
</ModalFooter>
|
||||||
onClick={onSettingsModalClose}
|
</ModalContent>
|
||||||
/>
|
</Modal>
|
||||||
</ModalFooter>
|
|
||||||
</ModalContent>
|
|
||||||
</Modal>
|
|
||||||
|
|
||||||
<Modal
|
<Modal
|
||||||
closeOnOverlayClick={false}
|
closeOnOverlayClick={false}
|
||||||
isOpen={isRefreshModalOpen}
|
isOpen={isRefreshModalOpen}
|
||||||
onClose={onRefreshModalClose}
|
onClose={onRefreshModalClose}
|
||||||
isCentered
|
isCentered
|
||||||
>
|
>
|
||||||
<ModalOverlay bg='blackAlpha.300' backdropFilter='blur(40px)' />
|
<ModalOverlay bg="blackAlpha.300" backdropFilter="blur(40px)" />
|
||||||
<ModalContent>
|
<ModalContent>
|
||||||
<ModalBody pb={6} pt={6}>
|
<ModalBody pb={6} pt={6}>
|
||||||
<Flex justifyContent={'center'}>
|
<Flex justifyContent={'center'}>
|
||||||
<Text fontSize={'lg'}>
|
<Text fontSize={'lg'}>
|
||||||
Web UI has been reset. Refresh the page to
|
Web UI has been reset. Refresh the page to reload.
|
||||||
reload.
|
</Text>
|
||||||
</Text>
|
</Flex>
|
||||||
</Flex>
|
</ModalBody>
|
||||||
</ModalBody>
|
</ModalContent>
|
||||||
</ModalContent>
|
</Modal>
|
||||||
</Modal>
|
</>
|
||||||
</>
|
);
|
||||||
);
|
|
||||||
};
|
};
|
||||||
|
|
||||||
export default SettingsModal;
|
export default SettingsModal;
|
||||||
|
120
frontend/src/features/system/SiteHeader.tsx
Normal file
120
frontend/src/features/system/SiteHeader.tsx
Normal file
@ -0,0 +1,120 @@
|
|||||||
|
import {
|
||||||
|
Flex,
|
||||||
|
Heading,
|
||||||
|
IconButton,
|
||||||
|
Link,
|
||||||
|
Spacer,
|
||||||
|
Text,
|
||||||
|
useColorMode,
|
||||||
|
} from '@chakra-ui/react';
|
||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { isEqual } from 'lodash';
|
||||||
|
|
||||||
|
import { FaSun, FaMoon, FaGithub } from 'react-icons/fa';
|
||||||
|
import { MdHelp, MdSettings } from 'react-icons/md';
|
||||||
|
import { useAppSelector } from '../../app/store';
|
||||||
|
import { RootState } from '../../app/store';
|
||||||
|
import SettingsModal from '../system/SettingsModal';
|
||||||
|
import { SystemState } from '../system/systemSlice';
|
||||||
|
const systemSelector = createSelector(
|
||||||
|
(state: RootState) => state.system,
|
||||||
|
(system: SystemState) => {
|
||||||
|
return {
|
||||||
|
isConnected: system.isConnected,
|
||||||
|
isProcessing: system.isProcessing,
|
||||||
|
currentIteration: system.currentIteration,
|
||||||
|
totalIterations: system.totalIterations,
|
||||||
|
currentStatus: system.currentStatus,
|
||||||
|
};
|
||||||
|
},
|
||||||
|
{
|
||||||
|
memoizeOptions: { resultEqualityCheck: isEqual },
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Header, includes color mode toggle, settings button, status message.
|
||||||
|
*/
|
||||||
|
const SiteHeader = () => {
|
||||||
|
const { colorMode, toggleColorMode } = useColorMode();
|
||||||
|
const {
|
||||||
|
isConnected,
|
||||||
|
isProcessing,
|
||||||
|
currentIteration,
|
||||||
|
totalIterations,
|
||||||
|
currentStatus,
|
||||||
|
} = useAppSelector(systemSelector);
|
||||||
|
|
||||||
|
const statusMessageTextColor = isConnected ? 'green.500' : 'red.500';
|
||||||
|
|
||||||
|
const colorModeIcon = colorMode == 'light' ? <FaMoon /> : <FaSun />;
|
||||||
|
|
||||||
|
// Make FaMoon and FaSun icon apparent size consistent
|
||||||
|
const colorModeIconFontSize = colorMode == 'light' ? 18 : 20;
|
||||||
|
|
||||||
|
let statusMessage = currentStatus;
|
||||||
|
|
||||||
|
if (isProcessing) {
|
||||||
|
if (totalIterations > 1) {
|
||||||
|
statusMessage += ` [${currentIteration}/${totalIterations}]`;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Flex minWidth="max-content" alignItems="center" gap="1" pl={2} pr={1}>
|
||||||
|
<Heading size={'lg'}>InvokeUI</Heading>
|
||||||
|
|
||||||
|
<Spacer />
|
||||||
|
|
||||||
|
<Text textColor={statusMessageTextColor}>{statusMessage}</Text>
|
||||||
|
|
||||||
|
<SettingsModal>
|
||||||
|
<IconButton
|
||||||
|
aria-label="Settings"
|
||||||
|
variant="link"
|
||||||
|
fontSize={24}
|
||||||
|
size={'sm'}
|
||||||
|
icon={<MdSettings />}
|
||||||
|
/>
|
||||||
|
</SettingsModal>
|
||||||
|
|
||||||
|
<IconButton
|
||||||
|
aria-label="Link to Github Issues"
|
||||||
|
variant="link"
|
||||||
|
fontSize={23}
|
||||||
|
size={'sm'}
|
||||||
|
icon={
|
||||||
|
<Link
|
||||||
|
isExternal
|
||||||
|
href="http://github.com/lstein/stable-diffusion/issues"
|
||||||
|
>
|
||||||
|
<MdHelp />
|
||||||
|
</Link>
|
||||||
|
}
|
||||||
|
/>
|
||||||
|
|
||||||
|
<IconButton
|
||||||
|
aria-label="Link to Github Repo"
|
||||||
|
variant="link"
|
||||||
|
fontSize={20}
|
||||||
|
size={'sm'}
|
||||||
|
icon={
|
||||||
|
<Link isExternal href="http://github.com/lstein/stable-diffusion">
|
||||||
|
<FaGithub />
|
||||||
|
</Link>
|
||||||
|
}
|
||||||
|
/>
|
||||||
|
|
||||||
|
<IconButton
|
||||||
|
aria-label="Toggle Dark Mode"
|
||||||
|
onClick={toggleColorMode}
|
||||||
|
variant="link"
|
||||||
|
size={'sm'}
|
||||||
|
fontSize={colorModeIconFontSize}
|
||||||
|
icon={colorModeIcon}
|
||||||
|
/>
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default SiteHeader;
|
@ -1,10 +1,13 @@
|
|||||||
import { createSlice } from '@reduxjs/toolkit';
|
import { createSlice } from '@reduxjs/toolkit';
|
||||||
import type { PayloadAction } from '@reduxjs/toolkit';
|
import type { PayloadAction } from '@reduxjs/toolkit';
|
||||||
import dateFormat from 'dateformat';
|
|
||||||
import { ExpandedIndex } from '@chakra-ui/react';
|
import { ExpandedIndex } from '@chakra-ui/react';
|
||||||
|
import * as InvokeAI from '../../app/invokeai'
|
||||||
|
|
||||||
|
export type LogLevel = 'info' | 'warning' | 'error';
|
||||||
|
|
||||||
export interface LogEntry {
|
export interface LogEntry {
|
||||||
timestamp: string;
|
timestamp: string;
|
||||||
|
level: LogLevel;
|
||||||
message: string;
|
message: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -12,10 +15,8 @@ export interface Log {
|
|||||||
[index: number]: LogEntry;
|
[index: number]: LogEntry;
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface SystemState {
|
export interface SystemState extends InvokeAI.SystemStatus, InvokeAI.SystemConfig {
|
||||||
shouldDisplayInProgress: boolean;
|
shouldDisplayInProgress: boolean;
|
||||||
isProcessing: boolean;
|
|
||||||
currentStep: number;
|
|
||||||
log: Array<LogEntry>;
|
log: Array<LogEntry>;
|
||||||
shouldShowLogViewer: boolean;
|
shouldShowLogViewer: boolean;
|
||||||
isGFPGANAvailable: boolean;
|
isGFPGANAvailable: boolean;
|
||||||
@ -24,12 +25,17 @@ export interface SystemState {
|
|||||||
socketId: string;
|
socketId: string;
|
||||||
shouldConfirmOnDelete: boolean;
|
shouldConfirmOnDelete: boolean;
|
||||||
openAccordions: ExpandedIndex;
|
openAccordions: ExpandedIndex;
|
||||||
|
currentStep: number;
|
||||||
|
totalSteps: number;
|
||||||
|
currentIteration: number;
|
||||||
|
totalIterations: number;
|
||||||
|
currentStatus: string;
|
||||||
|
currentStatusHasSteps: boolean;
|
||||||
}
|
}
|
||||||
|
|
||||||
const initialSystemState = {
|
const initialSystemState = {
|
||||||
isConnected: false,
|
isConnected: false,
|
||||||
isProcessing: false,
|
isProcessing: false,
|
||||||
currentStep: 0,
|
|
||||||
log: [],
|
log: [],
|
||||||
shouldShowLogViewer: false,
|
shouldShowLogViewer: false,
|
||||||
shouldDisplayInProgress: false,
|
shouldDisplayInProgress: false,
|
||||||
@ -38,6 +44,17 @@ const initialSystemState = {
|
|||||||
socketId: '',
|
socketId: '',
|
||||||
shouldConfirmOnDelete: true,
|
shouldConfirmOnDelete: true,
|
||||||
openAccordions: [0],
|
openAccordions: [0],
|
||||||
|
currentStep: 0,
|
||||||
|
totalSteps: 0,
|
||||||
|
currentIteration: 0,
|
||||||
|
totalIterations: 0,
|
||||||
|
currentStatus: '',
|
||||||
|
currentStatusHasSteps: false,
|
||||||
|
model: '',
|
||||||
|
model_id: '',
|
||||||
|
model_hash: '',
|
||||||
|
app_id: '',
|
||||||
|
app_version: '',
|
||||||
};
|
};
|
||||||
|
|
||||||
const initialState: SystemState = initialSystemState;
|
const initialState: SystemState = initialSystemState;
|
||||||
@ -51,18 +68,35 @@ export const systemSlice = createSlice({
|
|||||||
},
|
},
|
||||||
setIsProcessing: (state, action: PayloadAction<boolean>) => {
|
setIsProcessing: (state, action: PayloadAction<boolean>) => {
|
||||||
state.isProcessing = action.payload;
|
state.isProcessing = action.payload;
|
||||||
if (action.payload === false) {
|
|
||||||
state.currentStep = 0;
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
setCurrentStep: (state, action: PayloadAction<number>) => {
|
setCurrentStatus: (state, action: PayloadAction<string>) => {
|
||||||
state.currentStep = action.payload;
|
state.currentStatus = action.payload;
|
||||||
},
|
},
|
||||||
addLogEntry: (state, action: PayloadAction<string>) => {
|
setSystemStatus: (state, action: PayloadAction<InvokeAI.SystemStatus>) => {
|
||||||
|
const currentStatus =
|
||||||
|
!action.payload.isProcessing && state.isConnected
|
||||||
|
? 'Connected'
|
||||||
|
: action.payload.currentStatus;
|
||||||
|
|
||||||
|
return { ...state, ...action.payload, currentStatus };
|
||||||
|
},
|
||||||
|
addLogEntry: (
|
||||||
|
state,
|
||||||
|
action: PayloadAction<{
|
||||||
|
timestamp: string;
|
||||||
|
message: string;
|
||||||
|
level?: LogLevel;
|
||||||
|
}>
|
||||||
|
) => {
|
||||||
|
const { timestamp, message, level } = action.payload;
|
||||||
|
const logLevel = level || 'info';
|
||||||
|
|
||||||
const entry: LogEntry = {
|
const entry: LogEntry = {
|
||||||
timestamp: dateFormat(new Date(), 'isoDateTime'),
|
timestamp,
|
||||||
message: action.payload,
|
message,
|
||||||
|
level: logLevel,
|
||||||
};
|
};
|
||||||
|
|
||||||
state.log.push(entry);
|
state.log.push(entry);
|
||||||
},
|
},
|
||||||
setShouldShowLogViewer: (state, action: PayloadAction<boolean>) => {
|
setShouldShowLogViewer: (state, action: PayloadAction<boolean>) => {
|
||||||
@ -80,19 +114,24 @@ export const systemSlice = createSlice({
|
|||||||
setOpenAccordions: (state, action: PayloadAction<ExpandedIndex>) => {
|
setOpenAccordions: (state, action: PayloadAction<ExpandedIndex>) => {
|
||||||
state.openAccordions = action.payload;
|
state.openAccordions = action.payload;
|
||||||
},
|
},
|
||||||
|
setSystemConfig: (state, action: PayloadAction<InvokeAI.SystemConfig>) => {
|
||||||
|
return { ...state, ...action.payload };
|
||||||
|
},
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
export const {
|
export const {
|
||||||
setShouldDisplayInProgress,
|
setShouldDisplayInProgress,
|
||||||
setIsProcessing,
|
setIsProcessing,
|
||||||
setCurrentStep,
|
|
||||||
addLogEntry,
|
addLogEntry,
|
||||||
setShouldShowLogViewer,
|
setShouldShowLogViewer,
|
||||||
setIsConnected,
|
setIsConnected,
|
||||||
setSocketId,
|
setSocketId,
|
||||||
setShouldConfirmOnDelete,
|
setShouldConfirmOnDelete,
|
||||||
setOpenAccordions,
|
setOpenAccordions,
|
||||||
|
setSystemStatus,
|
||||||
|
setCurrentStatus,
|
||||||
|
setSystemConfig,
|
||||||
} = systemSlice.actions;
|
} = systemSlice.actions;
|
||||||
|
|
||||||
export default systemSlice.reducer;
|
export default systemSlice.reducer;
|
||||||
|
@ -1,108 +0,0 @@
|
|||||||
import { createSelector } from '@reduxjs/toolkit';
|
|
||||||
import { isEqual } from 'lodash';
|
|
||||||
import { useMemo } from 'react';
|
|
||||||
import { useAppSelector } from '../../app/hooks';
|
|
||||||
import { RootState } from '../../app/store';
|
|
||||||
import { SDState } from '../sd/sdSlice';
|
|
||||||
import { validateSeedWeights } from '../sd/util/seedWeightPairs';
|
|
||||||
import { SystemState } from './systemSlice';
|
|
||||||
|
|
||||||
const sdSelector = createSelector(
|
|
||||||
(state: RootState) => state.sd,
|
|
||||||
(sd: SDState) => {
|
|
||||||
return {
|
|
||||||
prompt: sd.prompt,
|
|
||||||
shouldGenerateVariations: sd.shouldGenerateVariations,
|
|
||||||
seedWeights: sd.seedWeights,
|
|
||||||
maskPath: sd.maskPath,
|
|
||||||
initialImagePath: sd.initialImagePath,
|
|
||||||
seed: sd.seed,
|
|
||||||
};
|
|
||||||
},
|
|
||||||
{
|
|
||||||
memoizeOptions: {
|
|
||||||
resultEqualityCheck: isEqual,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
);
|
|
||||||
|
|
||||||
const systemSelector = createSelector(
|
|
||||||
(state: RootState) => state.system,
|
|
||||||
(system: SystemState) => {
|
|
||||||
return {
|
|
||||||
isProcessing: system.isProcessing,
|
|
||||||
isConnected: system.isConnected,
|
|
||||||
};
|
|
||||||
},
|
|
||||||
{
|
|
||||||
memoizeOptions: {
|
|
||||||
resultEqualityCheck: isEqual,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
);
|
|
||||||
|
|
||||||
/*
|
|
||||||
Checks relevant pieces of state to confirm generation will not deterministically fail.
|
|
||||||
|
|
||||||
This is used to prevent the 'Generate' button from being clicked.
|
|
||||||
|
|
||||||
Other parameter values may cause failure but we rely on input validation for those.
|
|
||||||
*/
|
|
||||||
const useCheckParameters = () => {
|
|
||||||
const {
|
|
||||||
prompt,
|
|
||||||
shouldGenerateVariations,
|
|
||||||
seedWeights,
|
|
||||||
maskPath,
|
|
||||||
initialImagePath,
|
|
||||||
seed,
|
|
||||||
} = useAppSelector(sdSelector);
|
|
||||||
|
|
||||||
const { isProcessing, isConnected } = useAppSelector(systemSelector);
|
|
||||||
|
|
||||||
return useMemo(() => {
|
|
||||||
// Cannot generate without a prompt
|
|
||||||
if (!prompt) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Cannot generate with a mask without img2img
|
|
||||||
if (maskPath && !initialImagePath) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: job queue
|
|
||||||
// Cannot generate if already processing an image
|
|
||||||
if (isProcessing) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Cannot generate if not connected
|
|
||||||
if (!isConnected) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Cannot generate variations without valid seed weights
|
|
||||||
if (
|
|
||||||
shouldGenerateVariations &&
|
|
||||||
(!(validateSeedWeights(seedWeights) || seedWeights === '') ||
|
|
||||||
seed === -1)
|
|
||||||
) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
// All good
|
|
||||||
return true;
|
|
||||||
}, [
|
|
||||||
prompt,
|
|
||||||
maskPath,
|
|
||||||
initialImagePath,
|
|
||||||
isProcessing,
|
|
||||||
isConnected,
|
|
||||||
shouldGenerateVariations,
|
|
||||||
seedWeights,
|
|
||||||
seed,
|
|
||||||
]);
|
|
||||||
};
|
|
||||||
|
|
||||||
export default useCheckParameters;
|
|
@ -8,9 +8,9 @@ import { persistStore } from 'redux-persist';
|
|||||||
|
|
||||||
export const persistor = persistStore(store);
|
export const persistor = persistStore(store);
|
||||||
|
|
||||||
import App from './App';
|
|
||||||
import { theme } from './app/theme';
|
import { theme } from './app/theme';
|
||||||
import Loading from './Loading';
|
import Loading from './Loading';
|
||||||
|
import App from './app/App';
|
||||||
|
|
||||||
ReactDOM.createRoot(document.getElementById('root') as HTMLElement).render(
|
ReactDOM.createRoot(document.getElementById('root') as HTMLElement).render(
|
||||||
<React.StrictMode>
|
<React.StrictMode>
|
||||||
|
@ -419,6 +419,13 @@
|
|||||||
compute-scroll-into-view "1.0.14"
|
compute-scroll-into-view "1.0.14"
|
||||||
copy-to-clipboard "3.3.1"
|
copy-to-clipboard "3.3.1"
|
||||||
|
|
||||||
|
"@chakra-ui/icon@3.0.10":
|
||||||
|
version "3.0.10"
|
||||||
|
resolved "https://registry.yarnpkg.com/@chakra-ui/icon/-/icon-3.0.10.tgz#1a11b5edb42a8af7aa5b6dec2bf2c6c4df1869fc"
|
||||||
|
integrity sha512-utO569d9bptEraJrEhuImfNzQ8v+a8PsQh8kTsodCzg8B16R3t5TTuoqeJqS6Nq16Vq6w87QbX3/4A73CNK5fw==
|
||||||
|
dependencies:
|
||||||
|
"@chakra-ui/shared-utils" "2.0.1"
|
||||||
|
|
||||||
"@chakra-ui/icon@3.0.9":
|
"@chakra-ui/icon@3.0.9":
|
||||||
version "3.0.9"
|
version "3.0.9"
|
||||||
resolved "https://registry.yarnpkg.com/@chakra-ui/icon/-/icon-3.0.9.tgz#ba127d9eefd727f62e9bce07a23eca39ae506744"
|
resolved "https://registry.yarnpkg.com/@chakra-ui/icon/-/icon-3.0.9.tgz#ba127d9eefd727f62e9bce07a23eca39ae506744"
|
||||||
@ -426,6 +433,13 @@
|
|||||||
dependencies:
|
dependencies:
|
||||||
"@chakra-ui/shared-utils" "2.0.1"
|
"@chakra-ui/shared-utils" "2.0.1"
|
||||||
|
|
||||||
|
"@chakra-ui/icons@^2.0.10":
|
||||||
|
version "2.0.10"
|
||||||
|
resolved "https://registry.yarnpkg.com/@chakra-ui/icons/-/icons-2.0.10.tgz#61aeb44c913c10e7ff77addc798494e50d66c760"
|
||||||
|
integrity sha512-hxMspvysOay2NsJyadM611F/Y4vVzJU/YkXTxsyBjm6v/DbENhpVmPnUf+kwwyl7dINNb9iOF+kuGxnuIEO1Tw==
|
||||||
|
dependencies:
|
||||||
|
"@chakra-ui/icon" "3.0.10"
|
||||||
|
|
||||||
"@chakra-ui/image@2.0.10":
|
"@chakra-ui/image@2.0.10":
|
||||||
version "2.0.10"
|
version "2.0.10"
|
||||||
resolved "https://registry.yarnpkg.com/@chakra-ui/image/-/image-2.0.10.tgz#712c0e1c579d959225bd8316d8d8f66cbeb95bb8"
|
resolved "https://registry.yarnpkg.com/@chakra-ui/image/-/image-2.0.10.tgz#712c0e1c579d959225bd8316d8d8f66cbeb95bb8"
|
||||||
|
@ -100,6 +100,13 @@ SAMPLER_CHOICES = [
|
|||||||
'plms',
|
'plms',
|
||||||
]
|
]
|
||||||
|
|
||||||
|
PRECISION_CHOICES = [
|
||||||
|
'auto',
|
||||||
|
'float32',
|
||||||
|
'autocast',
|
||||||
|
'float16',
|
||||||
|
]
|
||||||
|
|
||||||
# is there a way to pick this up during git commits?
|
# is there a way to pick this up during git commits?
|
||||||
APP_ID = 'lstein/stable-diffusion'
|
APP_ID = 'lstein/stable-diffusion'
|
||||||
APP_VERSION = 'v1.15'
|
APP_VERSION = 'v1.15'
|
||||||
@ -174,31 +181,37 @@ class Args(object):
|
|||||||
switches.append(f'-W {a["width"]}')
|
switches.append(f'-W {a["width"]}')
|
||||||
switches.append(f'-H {a["height"]}')
|
switches.append(f'-H {a["height"]}')
|
||||||
switches.append(f'-C {a["cfg_scale"]}')
|
switches.append(f'-C {a["cfg_scale"]}')
|
||||||
switches.append(f'-A {a["sampler_name"]}')
|
|
||||||
if a['grid']:
|
if a['grid']:
|
||||||
switches.append('--grid')
|
switches.append('--grid')
|
||||||
if a['seamless']:
|
if a['seamless']:
|
||||||
switches.append('--seamless')
|
switches.append('--seamless')
|
||||||
|
|
||||||
|
# img2img generations have parameters relevant only to them and have special handling
|
||||||
if a['init_img'] and len(a['init_img'])>0:
|
if a['init_img'] and len(a['init_img'])>0:
|
||||||
switches.append(f'-I {a["init_img"]}')
|
switches.append(f'-I {a["init_img"]}')
|
||||||
if a['init_mask'] and len(a['init_mask'])>0:
|
switches.append(f'-A ddim') # TODO: FIX ME WHEN IMG2IMG SUPPORTS ALL SAMPLERS
|
||||||
switches.append(f'-M {a["init_mask"]}')
|
if a['fit']:
|
||||||
if a['init_color'] and len(a['init_color'])>0:
|
switches.append(f'--fit')
|
||||||
switches.append(f'--init_color {a["init_color"]}')
|
if a['init_mask'] and len(a['init_mask'])>0:
|
||||||
if a['fit']:
|
switches.append(f'-M {a["init_mask"]}')
|
||||||
switches.append(f'--fit')
|
if a['init_color'] and len(a['init_color'])>0:
|
||||||
if a['init_img'] and a['strength'] and a['strength']>0:
|
switches.append(f'--init_color {a["init_color"]}')
|
||||||
switches.append(f'-f {a["strength"]}')
|
if a['strength'] and a['strength']>0:
|
||||||
|
switches.append(f'-f {a["strength"]}')
|
||||||
|
else:
|
||||||
|
switches.append(f'-A {a["sampler_name"]}')
|
||||||
|
|
||||||
|
# gfpgan-specific parameters
|
||||||
if a['gfpgan_strength']:
|
if a['gfpgan_strength']:
|
||||||
switches.append(f'-G {a["gfpgan_strength"]}')
|
switches.append(f'-G {a["gfpgan_strength"]}')
|
||||||
|
|
||||||
|
# esrgan-specific parameters
|
||||||
if a['upscale']:
|
if a['upscale']:
|
||||||
switches.append(f'-U {" ".join([str(u) for u in a["upscale"]])}')
|
switches.append(f'-U {" ".join([str(u) for u in a["upscale"]])}')
|
||||||
if a['embiggen']:
|
if a['embiggen']:
|
||||||
switches.append(f'--embiggen {" ".join([str(u) for u in a["embiggen"]])}')
|
switches.append(f'--embiggen {" ".join([str(u) for u in a["embiggen"]])}')
|
||||||
if a['embiggen_tiles']:
|
if a['embiggen_tiles']:
|
||||||
switches.append(f'--embiggen_tiles {" ".join([str(u) for u in a["embiggen_tiles"]])}')
|
switches.append(f'--embiggen_tiles {" ".join([str(u) for u in a["embiggen_tiles"]])}')
|
||||||
if a['variation_amount'] > 0:
|
|
||||||
switches.append(f'-v {a["variation_amount"]}')
|
|
||||||
if a['with_variations']:
|
if a['with_variations']:
|
||||||
formatted_variations = ','.join(f'{seed}:{weight}' for seed, weight in (a["with_variations"]))
|
formatted_variations = ','.join(f'{seed}:{weight}' for seed, weight in (a["with_variations"]))
|
||||||
switches.append(f'-V {formatted_variations}')
|
switches.append(f'-V {formatted_variations}')
|
||||||
@ -316,7 +329,16 @@ class Args(object):
|
|||||||
'--full_precision',
|
'--full_precision',
|
||||||
dest='full_precision',
|
dest='full_precision',
|
||||||
action='store_true',
|
action='store_true',
|
||||||
help='Use more memory-intensive full precision math for calculations',
|
help='Deprecated way to set --precision=float32',
|
||||||
|
)
|
||||||
|
model_group.add_argument(
|
||||||
|
'--precision',
|
||||||
|
dest='precision',
|
||||||
|
type=str,
|
||||||
|
choices=PRECISION_CHOICES,
|
||||||
|
metavar='PRECISION',
|
||||||
|
help=f'Set model precision. Defaults to auto selected based on device. Options: {", ".join(PRECISION_CHOICES)}',
|
||||||
|
default='auto',
|
||||||
)
|
)
|
||||||
file_group.add_argument(
|
file_group.add_argument(
|
||||||
'--from_file',
|
'--from_file',
|
||||||
@ -618,18 +640,24 @@ def metadata_dumps(opt,
|
|||||||
postprocessing=postprocessing
|
postprocessing=postprocessing
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO: This is just a hack until postprocessing pipeline work completed
|
# 'postprocessing' is either null or an array of postprocessing metadatal
|
||||||
image_dict['postprocessing'] = []
|
if postprocessing:
|
||||||
if image_dict['gfpgan_strength'] and image_dict['gfpgan_strength'] > 0:
|
# TODO: This is just a hack until postprocessing pipeline work completed
|
||||||
image_dict['postprocessing'].append('GFPGAN (not RFC compliant)')
|
image_dict['postprocessing'] = []
|
||||||
if image_dict['upscale'] and image_dict['upscale'][0] > 0:
|
|
||||||
image_dict['postprocessing'].append('ESRGAN (not RFC compliant)')
|
if image_dict['gfpgan_strength'] and image_dict['gfpgan_strength'] > 0:
|
||||||
|
image_dict['postprocessing'].append('GFPGAN (not RFC compliant)')
|
||||||
|
if image_dict['upscale'] and image_dict['upscale'][0] > 0:
|
||||||
|
image_dict['postprocessing'].append('ESRGAN (not RFC compliant)')
|
||||||
|
else:
|
||||||
|
image_dict['postprocessing'] = None
|
||||||
|
|
||||||
# remove any image keys not mentioned in RFC #266
|
# remove any image keys not mentioned in RFC #266
|
||||||
rfc266_img_fields = ['type','postprocessing','sampler','prompt','seed','variations','steps',
|
rfc266_img_fields = ['type','postprocessing','sampler','prompt','seed','variations','steps',
|
||||||
'cfg_scale','step_number','width','height','extra','strength']
|
'cfg_scale','step_number','width','height','extra','strength']
|
||||||
|
|
||||||
rfc_dict ={}
|
rfc_dict ={}
|
||||||
|
|
||||||
for item in image_dict.items():
|
for item in image_dict.items():
|
||||||
key,value = item
|
key,value = item
|
||||||
if key in rfc266_img_fields:
|
if key in rfc266_img_fields:
|
||||||
@ -644,18 +672,17 @@ def metadata_dumps(opt,
|
|||||||
subprompts = [{'prompt':x[0],'weight':x[1]} for x in subprompts]
|
subprompts = [{'prompt':x[0],'weight':x[1]} for x in subprompts]
|
||||||
rfc_dict['prompt'] = subprompts
|
rfc_dict['prompt'] = subprompts
|
||||||
|
|
||||||
# variations
|
# 'variations' should always exist and be an array, empty or consisting of {'seed': seed, 'weight': weight} pairs
|
||||||
if opt.with_variations:
|
rfc_dict['variations'] = [{'seed':x[0],'weight':x[1]} for x in opt.with_variations] if opt.with_variations else []
|
||||||
variations = [{'seed':x[0],'weight':x[1]} for x in opt.with_variations]
|
|
||||||
rfc_dict['variations'] = variations
|
|
||||||
|
|
||||||
if opt.init_img:
|
if opt.init_img:
|
||||||
rfc_dict['type'] = 'img2img'
|
rfc_dict['type'] = 'img2img'
|
||||||
rfc_dict['strength_steps'] = rfc_dict.pop('strength')
|
rfc_dict['strength_steps'] = rfc_dict.pop('strength')
|
||||||
rfc_dict['orig_hash'] = calculate_init_img_hash(opt.init_img)
|
rfc_dict['orig_hash'] = calculate_init_img_hash(opt.init_img)
|
||||||
rfc_dict['sampler'] = 'ddim' # FIX ME WHEN IMG2IMG SUPPORTS ALL SAMPLERS
|
rfc_dict['sampler'] = 'ddim' # TODO: FIX ME WHEN IMG2IMG SUPPORTS ALL SAMPLERS
|
||||||
else:
|
else:
|
||||||
rfc_dict['type'] = 'txt2img'
|
rfc_dict['type'] = 'txt2img'
|
||||||
|
rfc_dict.pop('strength')
|
||||||
|
|
||||||
if len(seeds)==0 and opt.seed:
|
if len(seeds)==0 and opt.seed:
|
||||||
seeds=[seed]
|
seeds=[seed]
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
import torch
|
import torch
|
||||||
from torch import autocast
|
from torch import autocast
|
||||||
from contextlib import contextmanager, nullcontext
|
from contextlib import nullcontext
|
||||||
|
|
||||||
def choose_torch_device() -> str:
|
def choose_torch_device() -> str:
|
||||||
'''Convenience routine for guessing which GPU device to run model on'''
|
'''Convenience routine for guessing which GPU device to run model on'''
|
||||||
@ -10,15 +10,18 @@ def choose_torch_device() -> str:
|
|||||||
return 'mps'
|
return 'mps'
|
||||||
return 'cpu'
|
return 'cpu'
|
||||||
|
|
||||||
def choose_autocast_device(device):
|
def choose_precision(device) -> str:
|
||||||
'''Returns an autocast compatible device from a torch device'''
|
'''Returns an appropriate precision for the given torch device'''
|
||||||
device_type = device.type # this returns 'mps' on M1
|
if device.type == 'cuda':
|
||||||
# autocast only for cuda, but GTX 16xx have issues with it
|
device_name = torch.cuda.get_device_name(device)
|
||||||
if device_type == 'cuda':
|
if not ('GeForce GTX 1660' in device_name or 'GeForce GTX 1650' in device_name):
|
||||||
device_name = torch.cuda.get_device_name()
|
return 'float16'
|
||||||
if 'GeForce GTX 1660' in device_name or 'GeForce GTX 1650' in device_name:
|
return 'float32'
|
||||||
return device_type,nullcontext
|
|
||||||
else:
|
def choose_autocast(precision):
|
||||||
return device_type,autocast
|
'''Returns an autocast context or nullcontext for the given precision string'''
|
||||||
else:
|
# float16 currently requires autocast to avoid errors like:
|
||||||
return 'cpu',nullcontext
|
# 'expected scalar type Half but found Float'
|
||||||
|
if precision == 'autocast' or precision == 'float16':
|
||||||
|
return autocast
|
||||||
|
return nullcontext
|
||||||
|
@ -9,13 +9,14 @@ from tqdm import tqdm, trange
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
from einops import rearrange, repeat
|
from einops import rearrange, repeat
|
||||||
from pytorch_lightning import seed_everything
|
from pytorch_lightning import seed_everything
|
||||||
from ldm.dream.devices import choose_autocast_device
|
from ldm.dream.devices import choose_autocast
|
||||||
|
|
||||||
downsampling = 8
|
downsampling = 8
|
||||||
|
|
||||||
class Generator():
|
class Generator():
|
||||||
def __init__(self,model):
|
def __init__(self, model, precision):
|
||||||
self.model = model
|
self.model = model
|
||||||
|
self.precision = precision
|
||||||
self.seed = None
|
self.seed = None
|
||||||
self.latent_channels = model.channels
|
self.latent_channels = model.channels
|
||||||
self.downsampling_factor = downsampling # BUG: should come from model or config
|
self.downsampling_factor = downsampling # BUG: should come from model or config
|
||||||
@ -38,7 +39,7 @@ class Generator():
|
|||||||
def generate(self,prompt,init_image,width,height,iterations=1,seed=None,
|
def generate(self,prompt,init_image,width,height,iterations=1,seed=None,
|
||||||
image_callback=None, step_callback=None,
|
image_callback=None, step_callback=None,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
device_type,scope = choose_autocast_device(self.model.device)
|
scope = choose_autocast(self.precision)
|
||||||
make_image = self.get_make_image(
|
make_image = self.get_make_image(
|
||||||
prompt,
|
prompt,
|
||||||
init_image = init_image,
|
init_image = init_image,
|
||||||
@ -51,7 +52,7 @@ class Generator():
|
|||||||
results = []
|
results = []
|
||||||
seed = seed if seed else self.new_seed()
|
seed = seed if seed else self.new_seed()
|
||||||
seed, initial_noise = self.generate_initial_noise(seed, width, height)
|
seed, initial_noise = self.generate_initial_noise(seed, width, height)
|
||||||
with scope(device_type), self.model.ema_scope():
|
with scope(self.model.device.type), self.model.ema_scope():
|
||||||
for n in trange(iterations, desc='Generating'):
|
for n in trange(iterations, desc='Generating'):
|
||||||
x_T = None
|
x_T = None
|
||||||
if self.variation_amount > 0:
|
if self.variation_amount > 0:
|
||||||
|
@ -11,8 +11,8 @@ from ldm.models.diffusion.ddim import DDIMSampler
|
|||||||
from ldm.dream.generator.img2img import Img2Img
|
from ldm.dream.generator.img2img import Img2Img
|
||||||
|
|
||||||
class Embiggen(Generator):
|
class Embiggen(Generator):
|
||||||
def __init__(self,model):
|
def __init__(self, model, precision):
|
||||||
super().__init__(model)
|
super().__init__(model, precision)
|
||||||
self.init_latent = None
|
self.init_latent = None
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
|
@ -4,13 +4,13 @@ ldm.dream.generator.img2img descends from ldm.dream.generator
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from ldm.dream.devices import choose_autocast_device
|
from ldm.dream.devices import choose_autocast
|
||||||
from ldm.dream.generator.base import Generator
|
from ldm.dream.generator.base import Generator
|
||||||
from ldm.models.diffusion.ddim import DDIMSampler
|
from ldm.models.diffusion.ddim import DDIMSampler
|
||||||
|
|
||||||
class Img2Img(Generator):
|
class Img2Img(Generator):
|
||||||
def __init__(self,model):
|
def __init__(self, model, precision):
|
||||||
super().__init__(model)
|
super().__init__(model, precision)
|
||||||
self.init_latent = None # by get_noise()
|
self.init_latent = None # by get_noise()
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
@ -32,8 +32,8 @@ class Img2Img(Generator):
|
|||||||
ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
|
ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
|
||||||
)
|
)
|
||||||
|
|
||||||
device_type,scope = choose_autocast_device(self.model.device)
|
scope = choose_autocast(self.precision)
|
||||||
with scope(device_type):
|
with scope(self.model.device.type):
|
||||||
self.init_latent = self.model.get_first_stage_encoding(
|
self.init_latent = self.model.get_first_stage_encoding(
|
||||||
self.model.encode_first_stage(init_image)
|
self.model.encode_first_stage(init_image)
|
||||||
) # move to latent space
|
) # move to latent space
|
||||||
|
@ -5,14 +5,14 @@ ldm.dream.generator.inpaint descends from ldm.dream.generator
|
|||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from einops import rearrange, repeat
|
from einops import rearrange, repeat
|
||||||
from ldm.dream.devices import choose_autocast_device
|
from ldm.dream.devices import choose_autocast
|
||||||
from ldm.dream.generator.img2img import Img2Img
|
from ldm.dream.generator.img2img import Img2Img
|
||||||
from ldm.models.diffusion.ddim import DDIMSampler
|
from ldm.models.diffusion.ddim import DDIMSampler
|
||||||
|
|
||||||
class Inpaint(Img2Img):
|
class Inpaint(Img2Img):
|
||||||
def __init__(self,model):
|
def __init__(self, model, precision):
|
||||||
self.init_latent = None
|
self.init_latent = None
|
||||||
super().__init__(model)
|
super().__init__(model, precision)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
|
def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
|
||||||
@ -38,8 +38,8 @@ class Inpaint(Img2Img):
|
|||||||
ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
|
ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
|
||||||
)
|
)
|
||||||
|
|
||||||
device_type,scope = choose_autocast_device(self.model.device)
|
scope = choose_autocast(self.precision)
|
||||||
with scope(device_type):
|
with scope(self.model.device.type):
|
||||||
self.init_latent = self.model.get_first_stage_encoding(
|
self.init_latent = self.model.get_first_stage_encoding(
|
||||||
self.model.encode_first_stage(init_image)
|
self.model.encode_first_stage(init_image)
|
||||||
) # move to latent space
|
) # move to latent space
|
||||||
|
@ -7,8 +7,8 @@ import numpy as np
|
|||||||
from ldm.dream.generator.base import Generator
|
from ldm.dream.generator.base import Generator
|
||||||
|
|
||||||
class Txt2Img(Generator):
|
class Txt2Img(Generator):
|
||||||
def __init__(self,model):
|
def __init__(self, model, precision):
|
||||||
super().__init__(model)
|
super().__init__(model, precision)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
|
def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
|
||||||
|
61
ldm/dream/log.py
Normal file
61
ldm/dream/log.py
Normal file
@ -0,0 +1,61 @@
|
|||||||
|
"""
|
||||||
|
Functions for better format logging
|
||||||
|
write_log -- logs the name of the output image, prompt, and prompt args to the terminal and different types of file
|
||||||
|
1 write_log_message -- Writes a message to the console
|
||||||
|
2 write_log_files -- Writes a message to files
|
||||||
|
2.1 write_log_default -- File in plain text
|
||||||
|
2.2 write_log_txt -- File in txt format
|
||||||
|
2.3 write_log_markdown -- File in markdown format
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
def write_log(results, log_path, file_types, output_cntr):
|
||||||
|
"""
|
||||||
|
logs the name of the output image, prompt, and prompt args to the terminal and files
|
||||||
|
"""
|
||||||
|
output_cntr = write_log_message(results, output_cntr)
|
||||||
|
write_log_files(results, log_path, file_types)
|
||||||
|
return output_cntr
|
||||||
|
|
||||||
|
|
||||||
|
def write_log_message(results, output_cntr):
|
||||||
|
"""logs to the terminal"""
|
||||||
|
log_lines = [f"{path}: {prompt}\n" for path, prompt in results]
|
||||||
|
for l in log_lines:
|
||||||
|
output_cntr += 1
|
||||||
|
print(f"[{output_cntr}] {l}", end="")
|
||||||
|
return output_cntr
|
||||||
|
|
||||||
|
|
||||||
|
def write_log_files(results, log_path, file_types):
|
||||||
|
for file_type in file_types:
|
||||||
|
if file_type == "txt":
|
||||||
|
write_log_txt(log_path, results)
|
||||||
|
elif file_type == "md" or file_type == "markdown":
|
||||||
|
write_log_markdown(log_path, results)
|
||||||
|
else:
|
||||||
|
print(f"'{file_type}' format is not supported, so write in plain text")
|
||||||
|
write_log_default(log_path, results, file_type)
|
||||||
|
|
||||||
|
|
||||||
|
def write_log_default(log_path, results, file_type):
|
||||||
|
plain_txt_lines = [f"{path}: {prompt}\n" for path, prompt in results]
|
||||||
|
with open(log_path + "." + file_type, "a", encoding="utf-8") as file:
|
||||||
|
file.writelines(plain_txt_lines)
|
||||||
|
|
||||||
|
|
||||||
|
def write_log_txt(log_path, results):
|
||||||
|
txt_lines = [f"{path}: {prompt}\n" for path, prompt in results]
|
||||||
|
with open(log_path + ".txt", "a", encoding="utf-8") as file:
|
||||||
|
file.writelines(txt_lines)
|
||||||
|
|
||||||
|
|
||||||
|
def write_log_markdown(log_path, results):
|
||||||
|
md_lines = []
|
||||||
|
for path, prompt in results:
|
||||||
|
file_name = os.path.basename(path)
|
||||||
|
md_lines.append(f"## {file_name}\n\n\n{prompt}\n")
|
||||||
|
with open(log_path + ".md", "a", encoding="utf-8") as file:
|
||||||
|
file.writelines(md_lines)
|
@ -34,6 +34,7 @@ class PngWriter:
|
|||||||
# saves image named _image_ to outdir/name, writing metadata from prompt
|
# saves image named _image_ to outdir/name, writing metadata from prompt
|
||||||
# returns full path of output
|
# returns full path of output
|
||||||
def save_image_and_prompt_to_png(self, image, dream_prompt, name, metadata=None):
|
def save_image_and_prompt_to_png(self, image, dream_prompt, name, metadata=None):
|
||||||
|
print(f'self.outdir={self.outdir}, name={name}')
|
||||||
path = os.path.join(self.outdir, name)
|
path = os.path.join(self.outdir, name)
|
||||||
info = PngImagePlugin.PngInfo()
|
info = PngImagePlugin.PngInfo()
|
||||||
info.add_text('Dream', dream_prompt)
|
info.add_text('Dream', dream_prompt)
|
||||||
|
@ -29,7 +29,7 @@ from ldm.models.diffusion.plms import PLMSSampler
|
|||||||
from ldm.models.diffusion.ksampler import KSampler
|
from ldm.models.diffusion.ksampler import KSampler
|
||||||
from ldm.dream.pngwriter import PngWriter
|
from ldm.dream.pngwriter import PngWriter
|
||||||
from ldm.dream.image_util import InitImageResizer
|
from ldm.dream.image_util import InitImageResizer
|
||||||
from ldm.dream.devices import choose_torch_device
|
from ldm.dream.devices import choose_torch_device, choose_precision
|
||||||
from ldm.dream.conditioning import get_uc_and_c
|
from ldm.dream.conditioning import get_uc_and_c
|
||||||
|
|
||||||
def fix_func(orig):
|
def fix_func(orig):
|
||||||
@ -104,7 +104,7 @@ gr = Generate(
|
|||||||
# these values are set once and shouldn't be changed
|
# these values are set once and shouldn't be changed
|
||||||
conf = path to configuration file ('configs/models.yaml')
|
conf = path to configuration file ('configs/models.yaml')
|
||||||
model = symbolic name of the model in the configuration file
|
model = symbolic name of the model in the configuration file
|
||||||
full_precision = False
|
precision = float precision to be used
|
||||||
|
|
||||||
# this value is sticky and maintained between generation calls
|
# this value is sticky and maintained between generation calls
|
||||||
sampler_name = ['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms'] // k_lms
|
sampler_name = ['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms'] // k_lms
|
||||||
@ -130,6 +130,7 @@ class Generate:
|
|||||||
sampler_name = 'k_lms',
|
sampler_name = 'k_lms',
|
||||||
ddim_eta = 0.0, # deterministic
|
ddim_eta = 0.0, # deterministic
|
||||||
full_precision = False,
|
full_precision = False,
|
||||||
|
precision = 'auto',
|
||||||
# these are deprecated; if present they override values in the conf file
|
# these are deprecated; if present they override values in the conf file
|
||||||
weights = None,
|
weights = None,
|
||||||
config = None,
|
config = None,
|
||||||
@ -145,7 +146,7 @@ class Generate:
|
|||||||
self.cfg_scale = 7.5
|
self.cfg_scale = 7.5
|
||||||
self.sampler_name = sampler_name
|
self.sampler_name = sampler_name
|
||||||
self.ddim_eta = 0.0 # same seed always produces same image
|
self.ddim_eta = 0.0 # same seed always produces same image
|
||||||
self.full_precision = True if choose_torch_device() == 'mps' else full_precision
|
self.precision = precision
|
||||||
self.strength = 0.75
|
self.strength = 0.75
|
||||||
self.seamless = False
|
self.seamless = False
|
||||||
self.embedding_path = embedding_path
|
self.embedding_path = embedding_path
|
||||||
@ -162,6 +163,14 @@ class Generate:
|
|||||||
# it wasn't actually doing anything. This logic could be reinstated.
|
# it wasn't actually doing anything. This logic could be reinstated.
|
||||||
device_type = choose_torch_device()
|
device_type = choose_torch_device()
|
||||||
self.device = torch.device(device_type)
|
self.device = torch.device(device_type)
|
||||||
|
if full_precision:
|
||||||
|
if self.precision != 'auto':
|
||||||
|
raise ValueError('Remove --full_precision / -F if using --precision')
|
||||||
|
print('Please remove deprecated --full_precision / -F')
|
||||||
|
print('If auto config does not work you can use --precision=float32')
|
||||||
|
self.precision = 'float32'
|
||||||
|
if self.precision == 'auto':
|
||||||
|
self.precision = choose_precision(self.device)
|
||||||
|
|
||||||
# for VRAM usage statistics
|
# for VRAM usage statistics
|
||||||
self.session_peakmem = torch.cuda.max_memory_allocated() if self._has_cuda else None
|
self.session_peakmem = torch.cuda.max_memory_allocated() if self._has_cuda else None
|
||||||
@ -440,25 +449,25 @@ class Generate:
|
|||||||
def _make_img2img(self):
|
def _make_img2img(self):
|
||||||
if not self.generators.get('img2img'):
|
if not self.generators.get('img2img'):
|
||||||
from ldm.dream.generator.img2img import Img2Img
|
from ldm.dream.generator.img2img import Img2Img
|
||||||
self.generators['img2img'] = Img2Img(self.model)
|
self.generators['img2img'] = Img2Img(self.model, self.precision)
|
||||||
return self.generators['img2img']
|
return self.generators['img2img']
|
||||||
|
|
||||||
def _make_embiggen(self):
|
def _make_embiggen(self):
|
||||||
if not self.generators.get('embiggen'):
|
if not self.generators.get('embiggen'):
|
||||||
from ldm.dream.generator.embiggen import Embiggen
|
from ldm.dream.generator.embiggen import Embiggen
|
||||||
self.generators['embiggen'] = Embiggen(self.model)
|
self.generators['embiggen'] = Embiggen(self.model, self.precision)
|
||||||
return self.generators['embiggen']
|
return self.generators['embiggen']
|
||||||
|
|
||||||
def _make_txt2img(self):
|
def _make_txt2img(self):
|
||||||
if not self.generators.get('txt2img'):
|
if not self.generators.get('txt2img'):
|
||||||
from ldm.dream.generator.txt2img import Txt2Img
|
from ldm.dream.generator.txt2img import Txt2Img
|
||||||
self.generators['txt2img'] = Txt2Img(self.model)
|
self.generators['txt2img'] = Txt2Img(self.model, self.precision)
|
||||||
return self.generators['txt2img']
|
return self.generators['txt2img']
|
||||||
|
|
||||||
def _make_inpaint(self):
|
def _make_inpaint(self):
|
||||||
if not self.generators.get('inpaint'):
|
if not self.generators.get('inpaint'):
|
||||||
from ldm.dream.generator.inpaint import Inpaint
|
from ldm.dream.generator.inpaint import Inpaint
|
||||||
self.generators['inpaint'] = Inpaint(self.model)
|
self.generators['inpaint'] = Inpaint(self.model, self.precision)
|
||||||
return self.generators['inpaint']
|
return self.generators['inpaint']
|
||||||
|
|
||||||
def load_model(self):
|
def load_model(self):
|
||||||
@ -469,7 +478,7 @@ class Generate:
|
|||||||
model = self._load_model_from_config(self.config, self.weights)
|
model = self._load_model_from_config(self.config, self.weights)
|
||||||
if self.embedding_path is not None:
|
if self.embedding_path is not None:
|
||||||
model.embedding_manager.load(
|
model.embedding_manager.load(
|
||||||
self.embedding_path, self.full_precision
|
self.embedding_path, self.precision == 'float32' or self.precision == 'autocast'
|
||||||
)
|
)
|
||||||
self.model = model.to(self.device)
|
self.model = model.to(self.device)
|
||||||
# model.to doesn't change the cond_stage_model.device used to move the tokenizer output, so set it here
|
# model.to doesn't change the cond_stage_model.device used to move the tokenizer output, so set it here
|
||||||
@ -620,15 +629,12 @@ class Generate:
|
|||||||
model = instantiate_from_config(c.model)
|
model = instantiate_from_config(c.model)
|
||||||
m, u = model.load_state_dict(sd, strict=False)
|
m, u = model.load_state_dict(sd, strict=False)
|
||||||
|
|
||||||
if self.full_precision:
|
if self.precision == 'float16':
|
||||||
print(
|
print('Using faster float16 precision')
|
||||||
'>> Using slower but more accurate full-precision math (--full_precision)'
|
model.to(torch.float16)
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
print(
|
print('Using more accurate float32 precision')
|
||||||
'>> Using half precision math. Call with --full_precision to use more accurate but VRAM-intensive full precision.'
|
|
||||||
)
|
|
||||||
model.half()
|
|
||||||
model.to(self.device)
|
model.to(self.device)
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
|
3
pyproject.toml
Normal file
3
pyproject.toml
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
[tool.blue]
|
||||||
|
line-length = 90
|
||||||
|
target-version = ['py310']
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user