Merge branch 'development' into mkdocs-updates

This commit is contained in:
Lincoln Stein 2022-09-20 17:11:43 -04:00 committed by GitHub
commit 555f21cd25
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
104 changed files with 5425 additions and 4364 deletions

View File

@ -5,8 +5,7 @@ SAMPLES_DIR=${OUT_DIR}
python scripts/dream.py \
--from_file ${PROMPT_FILE} \
--outdir ${OUT_DIR} \
--sampler plms \
--full_precision
--sampler plms
# original output by CompVis/stable-diffusion
IMAGE1=".dev_scripts/images/v1_4_astronaut_rides_horse_plms_step50_seed42.png"

View File

@ -85,9 +85,9 @@ jobs:
fi
# Utterly hacky, but I don't know how else to do this
if [[ ${{ github.ref }} == 'refs/heads/master' ]]; then
time ${{ steps.vars.outputs.PYTHON_BIN }} scripts/dream.py --from_file tests/preflight_prompts.txt --full_precision
time ${{ steps.vars.outputs.PYTHON_BIN }} scripts/dream.py --from_file tests/preflight_prompts.txt
elif [[ ${{ github.ref }} == 'refs/heads/development' ]]; then
time ${{ steps.vars.outputs.PYTHON_BIN }} scripts/dream.py --from_file tests/dev_prompts.txt --full_precision
time ${{ steps.vars.outputs.PYTHON_BIN }} scripts/dream.py --from_file tests/dev_prompts.txt
fi
mkdir -p outputs/img-samples
- name: Archive results

View File

@ -86,17 +86,14 @@ You wil need one of the following:
- At least 6 GB of free disk space for the machine learning model, Python, and all its dependencies.
> Note
>
> If you have an Nvidia 10xx series card (e.g. the 1080ti), please run the dream script in
> full-precision mode as shown below.
#### Note
Similarly, specify full-precision mode on Apple M1 hardware.
To run in full-precision mode, start `dream.py` with the `--full_precision` flag:
Precision is auto configured based on the device. If however you encounter
errors like 'expected type Float but found Half' or 'not implemented for Half'
you can try starting `dream.py` with the `--precision=float32` flag:
```bash
(ldm) ~/stable-diffusion$ python scripts/dream.py --full_precision
(ldm) ~/stable-diffusion$ python scripts/dream.py --precision=float32
```
### Features
@ -125,6 +122,11 @@ To run in full-precision mode, start `dream.py` with the `--full_precision` flag
### Latest Changes
- vNEXT (TODO 2022)
- Deprecated `--full_precision` / `-F`. Simply omit it and `dream.py` will auto
configure. To switch away from auto use the new flag like `--precision=float32`.
- v1.14 (11 September 2022)
- Memory optimizations for small-RAM cards. 512x512 now possible on 4 GB GPUs.

View File

@ -2,14 +2,14 @@ from modules.parse_seed_weights import parse_seed_weights
import argparse
SAMPLER_CHOICES = [
'ddim',
'k_dpm_2_a',
'k_dpm_2',
'k_euler_a',
'k_euler',
'k_heun',
'k_lms',
'plms',
"ddim",
"k_dpm_2_a",
"k_dpm_2",
"k_euler_a",
"k_euler",
"k_heun",
"k_lms",
"plms",
]
@ -20,194 +20,42 @@ def parameters_to_command(params):
switches = list()
if 'prompt' in params:
if "prompt" in params:
switches.append(f'"{params["prompt"]}"')
if 'steps' in params:
if "steps" in params:
switches.append(f'-s {params["steps"]}')
if 'seed' in params:
if "seed" in params:
switches.append(f'-S {params["seed"]}')
if 'width' in params:
if "width" in params:
switches.append(f'-W {params["width"]}')
if 'height' in params:
if "height" in params:
switches.append(f'-H {params["height"]}')
if 'cfg_scale' in params:
if "cfg_scale" in params:
switches.append(f'-C {params["cfg_scale"]}')
if 'sampler_name' in params:
if "sampler_name" in params:
switches.append(f'-A {params["sampler_name"]}')
if 'seamless' in params and params["seamless"] == True:
switches.append(f'--seamless')
if 'init_img' in params and len(params['init_img']) > 0:
if "seamless" in params and params["seamless"] == True:
switches.append(f"--seamless")
if "init_img" in params and len(params["init_img"]) > 0:
switches.append(f'-I {params["init_img"]}')
if 'init_mask' in params and len(params['init_mask']) > 0:
if "init_mask" in params and len(params["init_mask"]) > 0:
switches.append(f'-M {params["init_mask"]}')
if 'init_color' in params and len(params['init_color']) > 0:
if "init_color" in params and len(params["init_color"]) > 0:
switches.append(f'--init_color {params["init_color"]}')
if 'strength' in params and 'init_img' in params:
if "strength" in params and "init_img" in params:
switches.append(f'-f {params["strength"]}')
if 'fit' in params and params["fit"] == True:
switches.append(f'--fit')
if 'gfpgan_strength' in params and params["gfpgan_strength"]:
if "fit" in params and params["fit"] == True:
switches.append(f"--fit")
if "gfpgan_strength" in params and params["gfpgan_strength"]:
switches.append(f'-G {params["gfpgan_strength"]}')
if 'upscale' in params and params["upscale"]:
if "upscale" in params and params["upscale"]:
switches.append(f'-U {params["upscale"][0]} {params["upscale"][1]}')
if 'variation_amount' in params and params['variation_amount'] > 0:
if "variation_amount" in params and params["variation_amount"] > 0:
switches.append(f'-v {params["variation_amount"]}')
if 'with_variations' in params:
seed_weight_pairs = ','.join(f'{seed}:{weight}' for seed, weight in params["with_variations"])
switches.append(f'-V {seed_weight_pairs}')
if "with_variations" in params:
seed_weight_pairs = ",".join(
f"{seed}:{weight}" for seed, weight in params["with_variations"]
)
switches.append(f"-V {seed_weight_pairs}")
return ' '.join(switches)
def create_cmd_parser():
"""
This is simply a copy of the parser from `dream.py` with a change to give
prompt a default value. This is a temporary hack pending merge of #587 which
provides a better way to do this.
"""
parser = argparse.ArgumentParser(
description='Example: dream> a fantastic alien landscape -W1024 -H960 -s100 -n12',
exit_on_error=True,
)
parser.add_argument('prompt', nargs='?', default='')
parser.add_argument('-s', '--steps', type=int, help='Number of steps')
parser.add_argument(
'-S',
'--seed',
type=int,
help='Image seed; a +ve integer, or use -1 for the previous seed, -2 for the one before that, etc',
)
parser.add_argument(
'-n',
'--iterations',
type=int,
default=1,
help='Number of samplings to perform (slower, but will provide seeds for individual images)',
)
parser.add_argument(
'-W', '--width', type=int, help='Image width, multiple of 64'
)
parser.add_argument(
'-H', '--height', type=int, help='Image height, multiple of 64'
)
parser.add_argument(
'-C',
'--cfg_scale',
default=7.5,
type=float,
help='Classifier free guidance (CFG) scale - higher numbers cause generator to "try" harder.',
)
parser.add_argument(
'-g', '--grid', action='store_true', help='generate a grid'
)
parser.add_argument(
'--outdir',
'-o',
type=str,
default=None,
help='Directory to save generated images and a log of prompts and seeds',
)
parser.add_argument(
'--seamless',
action='store_true',
help='Change the model to seamless tiling (circular) mode',
)
parser.add_argument(
'-i',
'--individual',
action='store_true',
help='Generate individual files (default)',
)
parser.add_argument(
'-I',
'--init_img',
type=str,
help='Path to input image for img2img mode (supersedes width and height)',
)
parser.add_argument(
'-M',
'--init_mask',
type=str,
help='Path to input mask for inpainting mode (supersedes width and height)',
)
parser.add_argument(
'--init_color',
type=str,
help='Path to reference image for color correction (used for repeated img2img and inpainting)'
)
parser.add_argument(
'-T',
'-fit',
'--fit',
action='store_true',
help='If specified, will resize the input image to fit within the dimensions of width x height (512x512 default)',
)
parser.add_argument(
'-f',
'--strength',
default=0.75,
type=float,
help='Strength for noising/unnoising. 0.0 preserves image exactly, 1.0 replaces it completely',
)
parser.add_argument(
'-G',
'--gfpgan_strength',
default=0,
type=float,
help='The strength at which to apply the GFPGAN model to the result, in order to improve faces.',
)
parser.add_argument(
'-U',
'--upscale',
nargs='+',
default=None,
type=float,
help='Scale factor (2, 4) for upscaling followed by upscaling strength (0-1.0). If strength not specified, defaults to 0.75'
)
parser.add_argument(
'-save_orig',
'--save_original',
action='store_true',
help='Save original. Use it when upscaling to save both versions.',
)
# variants is going to be superseded by a generalized "prompt-morph" function
# parser.add_argument('-v','--variants',type=int,help="in img2img mode, the first generated image will get passed back to img2img to generate the requested number of variants")
parser.add_argument(
'-x',
'--skip_normalize',
action='store_true',
help='Skip subprompt weight normalization',
)
parser.add_argument(
'-A',
'-m',
'--sampler',
dest='sampler_name',
default=None,
type=str,
choices=SAMPLER_CHOICES,
metavar='SAMPLER_NAME',
help=f'Switch to a different sampler. Supported samplers: {", ".join(SAMPLER_CHOICES)}',
)
parser.add_argument(
'-t',
'--log_tokenization',
action='store_true',
help='shows how the prompt is split into tokens'
)
parser.add_argument(
'-v',
'--variation_amount',
default=0.0,
type=float,
help='If > 0, generates variations on the initial seed instead of random seeds per iteration. Must be between 0 and 1. Higher values will be more different.'
)
parser.add_argument(
'-V',
'--with_variations',
default=None,
type=str,
help='list of variations to apply, in the format `seed:weight,seed:weight,...'
)
return parser
return " ".join(switches)

View File

@ -6,7 +6,8 @@ import traceback
import eventlet
import glob
import shlex
import argparse
import math
import shutil
from flask_socketio import SocketIO
from flask import Flask, send_from_directory, url_for, jsonify
@ -15,13 +16,16 @@ from PIL import Image
from pytorch_lightning import logging
from threading import Event
from uuid import uuid4
from send2trash import send2trash
from ldm.gfpgan.gfpgan_tools import real_esrgan_upscale
from ldm.gfpgan.gfpgan_tools import run_gfpgan
from ldm.generate import Generate
from ldm.dream.pngwriter import PngWriter, retrieve_metadata
from ldm.dream.args import APP_ID, APP_VERSION, calculate_init_img_hash
from ldm.dream.conditioning import split_weighted_subprompts
from modules.parameters import parameters_to_command, create_cmd_parser
from modules.parameters import parameters_to_command
"""
@ -29,12 +33,14 @@ USER CONFIG
"""
output_dir = "outputs/" # Base output directory for images
#host = 'localhost' # Web & socket.io host
host = '0.0.0.0' # Web & socket.io host
# host = 'localhost' # Web & socket.io host
host = "localhost" # Web & socket.io host
port = 9090 # Web & socket.io port
verbose = False # enables copious socket.io logging
additional_allowed_origins = ['http://localhost:9090'] # additional CORS allowed origins
verbose = False # enables copious socket.io logging
additional_allowed_origins = [
"http://localhost:5173"
] # additional CORS allowed origins
model = "stable-diffusion-1.4"
"""
END USER CONFIG
@ -47,26 +53,23 @@ SERVER SETUP
# fix missing mimetypes on windows due to registry wonkiness
mimetypes.add_type('application/javascript', '.js')
mimetypes.add_type('text/css', '.css')
mimetypes.add_type("application/javascript", ".js")
mimetypes.add_type("text/css", ".css")
app = Flask(__name__, static_url_path='', static_folder='../frontend/dist/')
app = Flask(__name__, static_url_path="", static_folder="../frontend/dist/")
app.config['OUTPUTS_FOLDER'] = "../outputs"
app.config["OUTPUTS_FOLDER"] = "../outputs"
@app.route('/outputs/<path:filename>')
@app.route("/outputs/<path:filename>")
def outputs(filename):
return send_from_directory(
app.config['OUTPUTS_FOLDER'],
filename
)
return send_from_directory(app.config["OUTPUTS_FOLDER"], filename)
@app.route("/", defaults={'path': ''})
@app.route("/", defaults={"path": ""})
def serve(path):
return send_from_directory(app.static_folder, 'index.html')
return send_from_directory(app.static_folder, "index.html")
logger = True if verbose else False
@ -78,12 +81,12 @@ max_http_buffer_size = 10000000
cors_allowed_origins = [f"http://{host}:{port}"] + additional_allowed_origins
socketio = SocketIO(
app,
logger=logger,
engineio_logger=engineio_logger,
max_http_buffer_size=max_http_buffer_size,
cors_allowed_origins=cors_allowed_origins,
)
app,
logger=logger,
engineio_logger=engineio_logger,
max_http_buffer_size=max_http_buffer_size,
cors_allowed_origins=cors_allowed_origins,
)
"""
@ -104,29 +107,31 @@ canceled = Event()
# reduce logging outputs to error
transformers.logging.set_verbosity_error()
logging.getLogger('pytorch_lightning').setLevel(logging.ERROR)
logging.getLogger("pytorch_lightning").setLevel(logging.ERROR)
# Initialize and load model
model = Generate()
model.load_model()
generate = Generate(model)
generate.load_model()
# location for "finished" images
result_path = os.path.join(output_dir, 'img-samples/')
result_path = os.path.join(output_dir, "img-samples/")
# temporary path for intermediates
intermediate_path = os.path.join(result_path, 'intermediates/')
intermediate_path = os.path.join(result_path, "intermediates/")
# path for user-uploaded init images and masks
init_path = os.path.join(result_path, 'init-images/')
mask_path = os.path.join(result_path, 'mask-images/')
init_image_path = os.path.join(result_path, "init-images/")
mask_image_path = os.path.join(result_path, "mask-images/")
# txt log
log_path = os.path.join(result_path, 'dream_log.txt')
log_path = os.path.join(result_path, "dream_log.txt")
# make all output paths
[os.makedirs(path, exist_ok=True)
for path in [result_path, intermediate_path, init_path, mask_path]]
[
os.makedirs(path, exist_ok=True)
for path in [result_path, intermediate_path, init_image_path, mask_image_path]
]
"""
@ -139,126 +144,219 @@ SOCKET.IO LISTENERS
"""
@socketio.on('requestAllImages')
@socketio.on("requestSystemConfig")
def handle_request_capabilities():
print(f">> System config requested")
config = get_system_config()
socketio.emit("systemConfig", config)
@socketio.on("requestAllImages")
def handle_request_all_images():
print(f'>> All images requested')
parser = create_cmd_parser()
print(f">> All images requested")
paths = list(filter(os.path.isfile, glob.glob(result_path + "*.png")))
paths.sort(key=lambda x: os.path.getmtime(x))
image_array = []
for path in paths:
# image = Image.open(path)
all_metadata = retrieve_metadata(path)
if 'Dream' in all_metadata and not all_metadata['sd-metadata']:
metadata = vars(parser.parse_args(shlex.split(all_metadata['Dream'])))
else:
metadata = all_metadata['sd-metadata']
image_array.append({'path': path, 'metadata': metadata})
return make_response("OK", data=image_array)
metadata = retrieve_metadata(path)
image_array.append({"url": path, "metadata": metadata["sd-metadata"]})
socketio.emit("galleryImages", {"images": image_array})
eventlet.sleep(0)
@socketio.on('generateImage')
def handle_generate_image_event(generation_parameters, esrgan_parameters, gfpgan_parameters):
print(f'>> Image generation requested: {generation_parameters}\nESRGAN parameters: {esrgan_parameters}\nGFPGAN parameters: {gfpgan_parameters}')
generate_images(
generation_parameters,
esrgan_parameters,
gfpgan_parameters
@socketio.on("generateImage")
def handle_generate_image_event(
generation_parameters, esrgan_parameters, gfpgan_parameters
):
print(
f">> Image generation requested: {generation_parameters}\nESRGAN parameters: {esrgan_parameters}\nGFPGAN parameters: {gfpgan_parameters}"
)
return make_response("OK")
generate_images(generation_parameters, esrgan_parameters, gfpgan_parameters)
@socketio.on('runESRGAN')
@socketio.on("runESRGAN")
def handle_run_esrgan_event(original_image, esrgan_parameters):
print(f'>> ESRGAN upscale requested for "{original_image["url"]}": {esrgan_parameters}')
print(
f'>> ESRGAN upscale requested for "{original_image["url"]}": {esrgan_parameters}'
)
progress = {
"currentStep": 1,
"totalSteps": 1,
"currentIteration": 1,
"totalIterations": 1,
"currentStatus": "Preparing",
"isProcessing": True,
"currentStatusHasSteps": False,
}
socketio.emit("progressUpdate", progress)
eventlet.sleep(0)
image = Image.open(original_image["url"])
seed = original_image['metadata']['seed'] if 'seed' in original_image['metadata'] else 'unknown_seed'
seed = (
original_image["metadata"]["seed"]
if "seed" in original_image["metadata"]
else "unknown_seed"
)
progress["currentStatus"] = "Upscaling"
socketio.emit("progressUpdate", progress)
eventlet.sleep(0)
image = real_esrgan_upscale(
image=image,
upsampler_scale=esrgan_parameters['upscale'][0],
strength=esrgan_parameters['upscale'][1],
seed=seed
upsampler_scale=esrgan_parameters["upscale"][0],
strength=esrgan_parameters["upscale"][1],
seed=seed,
)
esrgan_parameters['seed'] = seed
path = save_image(image, esrgan_parameters, result_path, postprocessing='esrgan')
progress["currentStatus"] = "Saving image"
socketio.emit("progressUpdate", progress)
eventlet.sleep(0)
esrgan_parameters["seed"] = seed
metadata = parameters_to_post_processed_image_metadata(
parameters=esrgan_parameters,
original_image_path=original_image["url"],
type="esrgan",
)
command = parameters_to_command(esrgan_parameters)
path = save_image(image, command, metadata, result_path, postprocessing="esrgan")
write_log_message(f'[Upscaled] "{original_image["url"]}" > "{path}": {command}')
progress["currentStatus"] = "Finished"
progress["currentStep"] = 0
progress["totalSteps"] = 0
progress["currentIteration"] = 0
progress["totalIterations"] = 0
progress["isProcessing"] = False
socketio.emit("progressUpdate", progress)
eventlet.sleep(0)
socketio.emit(
'result', {'url': os.path.relpath(path), 'type': 'esrgan', 'uuid': original_image['uuid'],'metadata': esrgan_parameters})
"esrganResult",
{
"url": os.path.relpath(path),
"metadata": metadata,
},
)
@socketio.on('runGFPGAN')
@socketio.on("runGFPGAN")
def handle_run_gfpgan_event(original_image, gfpgan_parameters):
print(f'>> GFPGAN face fix requested for "{original_image["url"]}": {gfpgan_parameters}')
print(
f'>> GFPGAN face fix requested for "{original_image["url"]}": {gfpgan_parameters}'
)
progress = {
"currentStep": 1,
"totalSteps": 1,
"currentIteration": 1,
"totalIterations": 1,
"currentStatus": "Preparing",
"isProcessing": True,
"currentStatusHasSteps": False,
}
socketio.emit("progressUpdate", progress)
eventlet.sleep(0)
image = Image.open(original_image["url"])
seed = original_image['metadata']['seed'] if 'seed' in original_image['metadata'] else 'unknown_seed'
seed = (
original_image["metadata"]["seed"]
if "seed" in original_image["metadata"]
else "unknown_seed"
)
progress["currentStatus"] = "Fixing faces"
socketio.emit("progressUpdate", progress)
eventlet.sleep(0)
image = run_gfpgan(
image=image,
strength=gfpgan_parameters['gfpgan_strength'],
strength=gfpgan_parameters["gfpgan_strength"],
seed=seed,
upsampler_scale=1
upsampler_scale=1,
)
gfpgan_parameters['seed'] = seed
path = save_image(image, gfpgan_parameters, result_path, postprocessing='gfpgan')
progress["currentStatus"] = "Saving image"
socketio.emit("progressUpdate", progress)
eventlet.sleep(0)
gfpgan_parameters["seed"] = seed
metadata = parameters_to_post_processed_image_metadata(
parameters=gfpgan_parameters,
original_image_path=original_image["url"],
type="gfpgan",
)
command = parameters_to_command(gfpgan_parameters)
path = save_image(image, command, metadata, result_path, postprocessing="gfpgan")
write_log_message(f'[Fixed faces] "{original_image["url"]}" > "{path}": {command}')
progress["currentStatus"] = "Finished"
progress["currentStep"] = 0
progress["totalSteps"] = 0
progress["currentIteration"] = 0
progress["totalIterations"] = 0
progress["isProcessing"] = False
socketio.emit("progressUpdate", progress)
eventlet.sleep(0)
socketio.emit(
'result', {'url': os.path.relpath(path), 'type': 'gfpgan', 'uuid': original_image['uuid'],'metadata': gfpgan_parameters})
"gfpganResult",
{
"url": os.path.relpath(path),
"metadata": metadata,
},
)
@socketio.on('cancel')
@socketio.on("cancel")
def handle_cancel():
print(f'>> Cancel processing requested')
print(f">> Cancel processing requested")
canceled.set()
return make_response("OK")
socketio.emit("processingCanceled")
# TODO: I think this needs a safety mechanism.
@socketio.on('deleteImage')
def handle_delete_image(path):
@socketio.on("deleteImage")
def handle_delete_image(path, uuid):
print(f'>> Delete requested "{path}"')
Path(path).unlink()
return make_response("OK")
send2trash(path)
socketio.emit("imageDeleted", {"url": path, "uuid": uuid})
# TODO: I think this needs a safety mechanism.
@socketio.on('uploadInitialImage')
@socketio.on("uploadInitialImage")
def handle_upload_initial_image(bytes, name):
print(f'>> Init image upload requested "{name}"')
uuid = uuid4().hex
split = os.path.splitext(name)
name = f'{split[0]}.{uuid}{split[1]}'
file_path = os.path.join(init_path, name)
name = f"{split[0]}.{uuid}{split[1]}"
file_path = os.path.join(init_image_path, name)
os.makedirs(os.path.dirname(file_path), exist_ok=True)
newFile = open(file_path, "wb")
newFile.write(bytes)
return make_response("OK", data=file_path)
socketio.emit("initialImageUploaded", {"url": file_path, "uuid": ""})
# TODO: I think this needs a safety mechanism.
@socketio.on('uploadMaskImage')
@socketio.on("uploadMaskImage")
def handle_upload_mask_image(bytes, name):
print(f'>> Mask image upload requested "{name}"')
uuid = uuid4().hex
split = os.path.splitext(name)
name = f'{split[0]}.{uuid}{split[1]}'
file_path = os.path.join(mask_path, name)
name = f"{split[0]}.{uuid}{split[1]}"
file_path = os.path.join(mask_image_path, name)
os.makedirs(os.path.dirname(file_path), exist_ok=True)
newFile = open(file_path, "wb")
newFile.write(bytes)
return make_response("OK", data=file_path)
socketio.emit("maskImageUploaded", {"url": file_path, "uuid": ""})
"""
@ -266,114 +364,343 @@ END SOCKET.IO LISTENERS
"""
"""
ADDITIONAL FUNCTIONS
"""
def get_system_config():
return {
"model": "stable diffusion",
"model_id": model,
"model_hash": generate.model_hash,
"app_id": APP_ID,
"app_version": APP_VERSION,
}
def parameters_to_post_processed_image_metadata(parameters, original_image_path, type):
# top-level metadata minus `image` or `images`
metadata = get_system_config()
orig_hash = calculate_init_img_hash(original_image_path)
image = {"orig_path": original_image_path, "orig_hash": orig_hash}
if type == "esrgan":
image["type"] = "esrgan"
image["scale"] = parameters["upscale"][0]
image["strength"] = parameters["upscale"][1]
elif type == "gfpgan":
image["type"] = "gfpgan"
image["strength"] = parameters["gfpgan_strength"]
else:
raise TypeError(f"Invalid type: {type}")
metadata["image"] = image
return metadata
def parameters_to_generated_image_metadata(parameters):
# top-level metadata minus `image` or `images`
metadata = get_system_config()
# remove any image keys not mentioned in RFC #266
rfc266_img_fields = [
"type",
"postprocessing",
"sampler",
"prompt",
"seed",
"variations",
"steps",
"cfg_scale",
"step_number",
"width",
"height",
"extra",
"seamless",
]
rfc_dict = {}
for item in parameters.items():
key, value = item
if key in rfc266_img_fields:
rfc_dict[key] = value
postprocessing = []
# 'postprocessing' is either null or an
if "gfpgan_strength" in parameters:
postprocessing.append(
{"type": "gfpgan", "strength": float(parameters["gfpgan_strength"])}
)
if "upscale" in parameters:
postprocessing.append(
{
"type": "esrgan",
"scale": int(parameters["upscale"][0]),
"strength": float(parameters["upscale"][1]),
}
)
rfc_dict["postprocessing"] = postprocessing if len(postprocessing) > 0 else None
# semantic drift
rfc_dict["sampler"] = parameters["sampler_name"]
# display weighted subprompts (liable to change)
subprompts = split_weighted_subprompts(parameters["prompt"])
subprompts = [{"prompt": x[0], "weight": x[1]} for x in subprompts]
rfc_dict["prompt"] = subprompts
# 'variations' should always exist and be an array, empty or consisting of {'seed': seed, 'weight': weight} pairs
variations = []
if "with_variations" in parameters:
variations = [
{"seed": x[0], "weight": x[1]} for x in parameters["with_variations"]
]
rfc_dict["variations"] = variations
if "init_img" in parameters:
rfc_dict["type"] = "img2img"
rfc_dict["strength"] = parameters["strength"]
rfc_dict["fit"] = parameters["fit"] # TODO: Noncompliant
rfc_dict["orig_hash"] = calculate_init_img_hash(parameters["init_img"])
rfc_dict["init_image_path"] = parameters["init_img"] # TODO: Noncompliant
rfc_dict["sampler"] = "ddim" # TODO: FIX ME WHEN IMG2IMG SUPPORTS ALL SAMPLERS
if "init_mask" in parameters:
rfc_dict["mask_hash"] = calculate_init_img_hash(
parameters["init_mask"]
) # TODO: Noncompliant
rfc_dict["mask_image_path"] = parameters["init_mask"] # TODO: Noncompliant
else:
rfc_dict["type"] = "txt2img"
metadata["image"] = rfc_dict
return metadata
def make_unique_init_image_filename(name):
uuid = uuid4().hex
split = os.path.splitext(name)
name = f"{split[0]}.{uuid}{split[1]}"
return name
def write_log_message(message, log_path=log_path):
"""Logs the filename and parameters used to generate or process that image to log file"""
message = f'{message}\n'
with open(log_path, 'a', encoding='utf-8') as file:
message = f"{message}\n"
with open(log_path, "a", encoding="utf-8") as file:
file.writelines(message)
def make_response(status, message=None, data=None):
response = {'status': status}
if message is not None:
response['message'] = message
if data is not None:
response['data'] = data
return response
def save_image(image, parameters, output_dir, step_index=None, postprocessing=False):
seed = parameters['seed'] if 'seed' in parameters else 'unknown_seed'
def save_image(
image, command, metadata, output_dir, step_index=None, postprocessing=False
):
pngwriter = PngWriter(output_dir)
prefix = pngwriter.unique_prefix()
filename = f'{prefix}.{seed}'
seed = "unknown_seed"
if "image" in metadata:
if "seed" in metadata["image"]:
seed = metadata["image"]["seed"]
filename = f"{prefix}.{seed}"
if step_index:
filename += f'.{step_index}'
filename += f".{step_index}"
if postprocessing:
filename += f'.postprocessed'
filename += f".postprocessed"
filename += '.png'
filename += ".png"
command = parameters_to_command(parameters)
path = pngwriter.save_image_and_prompt_to_png(image, command, metadata=parameters, name=filename)
path = pngwriter.save_image_and_prompt_to_png(
image=image, dream_prompt=command, metadata=metadata, name=filename
)
return path
def calculate_real_steps(steps, strength, has_init_image):
return math.floor(strength * steps) if has_init_image else steps
def generate_images(generation_parameters, esrgan_parameters, gfpgan_parameters):
canceled.clear()
step_index = 1
"""
If a result image is used as an init image, and then deleted, we will want to be
able to use it as an init image in the future. Need to copy it.
If the init/mask image doesn't exist in the init_image_path/mask_image_path,
make a unique filename for it and copy it there.
"""
if "init_img" in generation_parameters:
filename = os.path.basename(generation_parameters["init_img"])
if not os.path.exists(os.path.join(init_image_path, filename)):
unique_filename = make_unique_init_image_filename(filename)
new_path = os.path.join(init_image_path, unique_filename)
shutil.copy(generation_parameters["init_img"], new_path)
generation_parameters["init_img"] = new_path
if "init_mask" in generation_parameters:
filename = os.path.basename(generation_parameters["init_mask"])
if not os.path.exists(os.path.join(mask_image_path, filename)):
unique_filename = make_unique_init_image_filename(filename)
new_path = os.path.join(init_image_path, unique_filename)
shutil.copy(generation_parameters["init_img"], new_path)
generation_parameters["init_mask"] = new_path
totalSteps = calculate_real_steps(
steps=generation_parameters["steps"],
strength=generation_parameters["strength"]
if "strength" in generation_parameters
else None,
has_init_image="init_img" in generation_parameters,
)
progress = {
"currentStep": 1,
"totalSteps": totalSteps,
"currentIteration": 1,
"totalIterations": generation_parameters["iterations"],
"currentStatus": "Preparing",
"isProcessing": True,
"currentStatusHasSteps": False,
}
socketio.emit("progressUpdate", progress)
eventlet.sleep(0)
def image_progress(sample, step):
if canceled.is_set():
raise CanceledException
nonlocal step_index
nonlocal generation_parameters
if generation_parameters["progress_images"] and step % 5 == 0 and step < generation_parameters['steps'] - 1:
image = model.sample_to_image(sample)
path = save_image(image, generation_parameters, intermediate_path, step_index)
nonlocal progress
progress["currentStep"] = step + 1
progress["currentStatus"] = "Generating"
progress["currentStatusHasSteps"] = True
if (
generation_parameters["progress_images"]
and step % 5 == 0
and step < generation_parameters["steps"] - 1
):
image = generate.sample_to_image(sample)
path = save_image(
image, generation_parameters, intermediate_path, step_index
)
step_index += 1
socketio.emit('intermediateResult', {
'url': os.path.relpath(path), 'metadata': generation_parameters})
socketio.emit('progress', {'step': step + 1})
socketio.emit(
"intermediateResult",
{"url": os.path.relpath(path), "metadata": generation_parameters},
)
socketio.emit("progressUpdate", progress)
eventlet.sleep(0)
def image_done(image, seed):
nonlocal generation_parameters
nonlocal esrgan_parameters
nonlocal gfpgan_parameters
nonlocal progress
step_index = 1
progress["currentStatus"] = "Generation complete"
socketio.emit("progressUpdate", progress)
eventlet.sleep(0)
all_parameters = generation_parameters
postprocessing = False
if esrgan_parameters:
progress["currentStatus"] = "Upscaling"
progress["currentStatusHasSteps"] = False
socketio.emit("progressUpdate", progress)
eventlet.sleep(0)
image = real_esrgan_upscale(
image=image,
strength=esrgan_parameters['strength'],
upsampler_scale=esrgan_parameters['level'],
seed=seed
strength=esrgan_parameters["strength"],
upsampler_scale=esrgan_parameters["level"],
seed=seed,
)
postprocessing = True
all_parameters["upscale"] = [esrgan_parameters['level'], esrgan_parameters['strength']]
all_parameters["upscale"] = [
esrgan_parameters["level"],
esrgan_parameters["strength"],
]
if gfpgan_parameters:
progress["currentStatus"] = "Fixing faces"
progress["currentStatusHasSteps"] = False
socketio.emit("progressUpdate", progress)
eventlet.sleep(0)
image = run_gfpgan(
image=image,
strength=gfpgan_parameters['strength'],
strength=gfpgan_parameters["strength"],
seed=seed,
upsampler_scale=1,
)
postprocessing = True
all_parameters["gfpgan_strength"] = gfpgan_parameters['strength']
all_parameters["gfpgan_strength"] = gfpgan_parameters["strength"]
all_parameters['seed'] = seed
all_parameters["seed"] = seed
progress["currentStatus"] = "Saving image"
socketio.emit("progressUpdate", progress)
eventlet.sleep(0)
path = save_image(image, all_parameters, result_path, postprocessing=postprocessing)
metadata = parameters_to_generated_image_metadata(all_parameters)
command = parameters_to_command(all_parameters)
print(f'Image generated: "{path}"')
path = save_image(
image, command, metadata, result_path, postprocessing=postprocessing
)
print(f'>> Image generated: "{path}"')
write_log_message(f'[Generated] "{path}": {command}')
if progress["totalIterations"] > progress["currentIteration"]:
progress["currentStep"] = 1
progress["currentIteration"] += 1
progress["currentStatus"] = "Iteration finished"
progress["currentStatusHasSteps"] = False
else:
progress["currentStep"] = 0
progress["totalSteps"] = 0
progress["currentIteration"] = 0
progress["totalIterations"] = 0
progress["currentStatus"] = "Finished"
progress["isProcessing"] = False
socketio.emit("progressUpdate", progress)
eventlet.sleep(0)
socketio.emit(
'result', {'url': os.path.relpath(path), 'type': 'generation', 'metadata': all_parameters})
"generationResult",
{"url": os.path.relpath(path), "metadata": metadata},
)
eventlet.sleep(0)
try:
model.prompt2image(
generate.prompt2image(
**generation_parameters,
step_callback=image_progress,
image_callback=image_done
image_callback=image_done,
)
except KeyboardInterrupt:
@ -381,7 +708,7 @@ def generate_images(generation_parameters, esrgan_parameters, gfpgan_parameters)
except CanceledException:
pass
except Exception as e:
socketio.emit('error', (str(e)))
socketio.emit("error", {"message": (str(e))})
print("\n")
traceback.print_exc()
print("\n")
@ -392,6 +719,6 @@ END ADDITIONAL FUNCTIONS
"""
if __name__ == '__main__':
print(f'Starting server at http://{host}:{port}')
if __name__ == "__main__":
print(f">> Starting server at http://{host}:{port}")
socketio.run(app, host=host, port=port)

Binary file not shown.

Before

Width:  |  Height:  |  Size: 22 KiB

After

Width:  |  Height:  |  Size: 22 KiB

View File

@ -59,9 +59,7 @@ Once the model is trained, specify the trained .pt or .bin file when starting
dream using
```bash
python3 ./scripts/dream.py \
--embedding_path /path/to/embedding.pt \
--full_precision
python3 ./scripts/dream.py --embedding_path /path/to/embedding.pt
```
Then, to utilize your subject at the dream prompt

View File

@ -97,6 +97,11 @@ You wil need one of the following:
```
## :octicons-log-16: Latest Changes
### vNEXT <small>(TODO 2022)</small>
- Deprecated `--full_precision` / `-F`. Simply omit it and `dream.py` will auto
configure. To switch away from auto use the new flag like `--precision=float32`.
### v1.14 <small>(11 September 2022)</small>
- Memory optimizations for small-RAM cards. 512x512 now possible on 4 GB GPUs.

View File

@ -106,7 +106,6 @@ PATH_TO_CKPT="$HOME/Downloads" # (1)!
ln -s "$PATH_TO_CKPT/sd-v1-4.ckpt" \
models/ldm/stable-diffusion-v1/model.ckpt
```
1. or wherever you saved sd-v1-4.ckpt
@ -548,5 +547,3 @@ Abort trap: 6
warnings.warn('resource_tracker: There appear to be %d '
```
Macs do not support `autocast/mixed-precision`, so you need to supply
`--full_precision` to use float32 everywhere.

View File

@ -32,6 +32,7 @@ dependencies:
- omegaconf==2.1.1
- onnx==1.12.0
- onnxruntime==1.12.1
- protobuf==3.20.1
- pudb==2022.1
- pytorch-lightning==1.6.5
- scipy==1.9.1
@ -48,6 +49,7 @@ dependencies:
- opencv-python==4.6.0
- protobuf==3.20.1
- realesrgan==0.2.5.0
- send2trash==1.8.0
- test-tube==0.7.5
- transformers==4.21.2
- torch-fidelity==0.3.0

View File

@ -20,6 +20,7 @@ dependencies:
- realesrgan==0.2.5.0
- test-tube>=0.7.5
- streamlit==1.12.0
- send2trash==1.8.0
- pillow==9.2.0
- einops==0.3.0
- torch-fidelity==0.3.0

View File

@ -1,85 +1,37 @@
# Stable Diffusion Web UI
Demo at https://peaceful-otter-7a427f.netlify.app/ (not connected to back end)
## Run
much of this readme is just notes for myself during dev work
- `python backend/server.py` serves both frontend and backend at http://localhost:9090
numpy rand: 0 to 4294967295
## Evironment
## Test and Build
Install [node](https://nodejs.org/en/download/) (includes npm) and optionally
[yarn](https://yarnpkg.com/getting-started/install).
from `frontend/`:
From `frontend/` run `npm install` / `yarn install` to install the frontend packages.
- `yarn dev` runs `tsc-watch`, which runs `vite build` on successful `tsc` transpilation
## Dev
from `.`:
1. From `frontend/`, run `npm dev` / `yarn dev` to start the dev server.
2. Note the address it starts up on (probably `http://localhost:5173/`).
3. Edit `backend/server.py`'s `additional_allowed_origins` to include this address, e.g.
`additional_allowed_origins = ['http://localhost:5173']`.
4. Leaving the dev server running, open a new terminal and go to the project root.
5. Run `python backend/server.py`.
6. Navigate to the dev server address e.g. `http://localhost:5173/`.
- `python backend/server.py` serves both frontend and backend at http://localhost:9090
To build for dev: `npm build-dev` / `yarn build-dev`
## API
`backend/server.py` serves the UI and provides a [socket.io](https://github.com/socketio/socket.io) API via [flask-socketio](https://github.com/miguelgrinberg/flask-socketio).
### Server Listeners
The server listens for these socket.io events:
`cancel`
- Cancels in-progress image generation
- Returns ack only
`generateImage`
- Accepts object of image parameters
- Generates an image
- Returns ack only (image generation function sends progress and result via separate events)
`deleteImage`
- Accepts file path to image
- Deletes image
- Returns ack only
`deleteAllImages` WIP
- Deletes all images in `outputs/`
- Returns ack only
`requestAllImages`
- Returns array of all images in `outputs/`
`requestCapabilities` WIP
- Returns capabilities of server (torch device, GFPGAN and ESRGAN availability, ???)
`sendImage` WIP
- Accepts a File and attributes
- Saves image
- Used to save init images which are not generated images
### Server Emitters
`progress`
- Emitted during each step in generation
- Sends a number from 0 to 1 representing percentage of steps completed
`result` WIP
- Emitted when an image generation has completed
- Sends a object:
```
{
url: relative_file_path,
metadata: image_metadata_object
}
```
To build for production: `npm build` / `yarn build`
## TODO
- Search repo for "TODO"
- My one gripe with Chakra: no way to disable all animations right now and drop the dependence on `framer-motion`. I would prefer to save the ~30kb on bundle and have zero animations. This is on the Chakra roadmap. See https://github.com/chakra-ui/chakra-ui/pull/6368 for last discussion on this. Need to check in on this issue periodically.
- Search repo for "TODO"
- My one gripe with Chakra: no way to disable all animations right now and drop the dependence on
`framer-motion`. I would prefer to save the ~30kb on bundle and have zero animations. This is on
the Chakra roadmap. See https://github.com/chakra-ui/chakra-ui/pull/6368 for last discussion on
this. Need to check in on this issue periodically.
- Mobile friendly layout
- Proper image gallery/viewer/manager
- Help tooltips and such

694
frontend/dist/assets/index.727a397b.js vendored Normal file

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -4,7 +4,7 @@
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Stable Diffusion Dream Server</title>
<script type="module" crossorigin src="/assets/index.cc5cde43.js"></script>
<script type="module" crossorigin src="/assets/index.727a397b.js"></script>
<link rel="stylesheet" href="/assets/index.447eb2a9.css">
</head>
<body>

View File

@ -3,7 +3,7 @@
<head>
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Stable Diffusion Dream Server</title>
<title>InvokeAI Stable Diffusion Dream Server</title>
</head>
<body>
<div id="root"></div>

View File

@ -1,16 +1,16 @@
{
"name": "sdui",
"name": "invoke-ai-ui",
"private": true,
"version": "0.0.0",
"version": "0.0.1",
"type": "module",
"scripts": {
"dev": "tsc-watch --onSuccess 'yarn run vite build -m development'",
"hmr": "vite dev",
"dev": "vite dev",
"build": "tsc && vite build",
"build-dev": "tsc && vite build -m development",
"preview": "vite preview"
},
"dependencies": {
"@chakra-ui/icons": "^2.0.10",
"@chakra-ui/react": "^2.3.1",
"@emotion/react": "^11.10.4",
"@emotion/styled": "^11.10.4",

View File

@ -1,60 +0,0 @@
import { Grid, GridItem } from '@chakra-ui/react';
import CurrentImage from './features/gallery/CurrentImage';
import LogViewer from './features/system/LogViewer';
import PromptInput from './features/sd/PromptInput';
import ProgressBar from './features/header/ProgressBar';
import { useEffect } from 'react';
import { useAppDispatch } from './app/hooks';
import { requestAllImages } from './app/socketio';
import ProcessButtons from './features/sd/ProcessButtons';
import ImageRoll from './features/gallery/ImageRoll';
import SiteHeader from './features/header/SiteHeader';
import OptionsAccordion from './features/sd/OptionsAccordion';
const App = () => {
const dispatch = useAppDispatch();
useEffect(() => {
dispatch(requestAllImages());
}, [dispatch]);
return (
<>
<Grid
width='100vw'
height='100vh'
templateAreas={`
"header header header header"
"progressBar progressBar progressBar progressBar"
"menu prompt processButtons imageRoll"
"menu currentImage currentImage imageRoll"`}
gridTemplateRows={'36px 10px 100px auto'}
gridTemplateColumns={'350px auto 100px 388px'}
gap={2}
>
<GridItem area={'header'} pt={1}>
<SiteHeader />
</GridItem>
<GridItem area={'progressBar'}>
<ProgressBar />
</GridItem>
<GridItem pl='2' area={'menu'} overflowY='scroll'>
<OptionsAccordion />
</GridItem>
<GridItem area={'prompt'}>
<PromptInput />
</GridItem>
<GridItem area={'processButtons'}>
<ProcessButtons />
</GridItem>
<GridItem area={'currentImage'}>
<CurrentImage />
</GridItem>
<GridItem pr='2' area={'imageRoll'} overflowY='scroll'>
<ImageRoll />
</GridItem>
</Grid>
<LogViewer />
</>
);
};
export default App;

69
frontend/src/app/App.tsx Normal file
View File

@ -0,0 +1,69 @@
import { Grid, GridItem } from '@chakra-ui/react';
import { useEffect, useState } from 'react';
import CurrentImageDisplay from '../features/gallery/CurrentImageDisplay';
import ImageGallery from '../features/gallery/ImageGallery';
import ProgressBar from '../features/system/ProgressBar';
import SiteHeader from '../features/system/SiteHeader';
import OptionsAccordion from '../features/options/OptionsAccordion';
import ProcessButtons from '../features/options/ProcessButtons';
import PromptInput from '../features/options/PromptInput';
import LogViewer from '../features/system/LogViewer';
import Loading from '../Loading';
import { useAppDispatch } from './store';
import { requestAllImages, requestSystemConfig } from './socketio/actions';
const App = () => {
const dispatch = useAppDispatch();
const [isReady, setIsReady] = useState<boolean>(false);
// Load images from the gallery once
useEffect(() => {
dispatch(requestAllImages());
dispatch(requestSystemConfig());
setIsReady(true);
}, [dispatch]);
return isReady ? (
<>
<Grid
width="100vw"
height="100vh"
templateAreas={`
"header header header header"
"progressBar progressBar progressBar progressBar"
"menu prompt processButtons imageRoll"
"menu currentImage currentImage imageRoll"`}
gridTemplateRows={'36px 10px 100px auto'}
gridTemplateColumns={'350px auto 100px 388px'}
gap={2}
>
<GridItem area={'header'} pt={1}>
<SiteHeader />
</GridItem>
<GridItem area={'progressBar'}>
<ProgressBar />
</GridItem>
<GridItem pl="2" area={'menu'} overflowY="scroll">
<OptionsAccordion />
</GridItem>
<GridItem area={'prompt'}>
<PromptInput />
</GridItem>
<GridItem area={'processButtons'}>
<ProcessButtons />
</GridItem>
<GridItem area={'currentImage'}>
<CurrentImageDisplay />
</GridItem>
<GridItem pr="2" area={'imageRoll'} overflowY="scroll">
<ImageGallery />
</GridItem>
</Grid>
<LogViewer />
</>
) : (
<Loading />
);
};
export default App;

View File

@ -2,52 +2,52 @@
// Valid samplers
export const SAMPLERS: Array<string> = [
'ddim',
'plms',
'k_lms',
'k_dpm_2',
'k_dpm_2_a',
'k_euler',
'k_euler_a',
'k_heun',
'ddim',
'plms',
'k_lms',
'k_dpm_2',
'k_dpm_2_a',
'k_euler',
'k_euler_a',
'k_heun',
];
// Valid image widths
export const WIDTHS: Array<number> = [
64, 128, 192, 256, 320, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960,
1024,
64, 128, 192, 256, 320, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960,
1024,
];
// Valid image heights
export const HEIGHTS: Array<number> = [
64, 128, 192, 256, 320, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960,
1024,
64, 128, 192, 256, 320, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960,
1024,
];
// Valid upscaling levels
export const UPSCALING_LEVELS: Array<{ key: string; value: number }> = [
{ key: '2x', value: 2 },
{ key: '4x', value: 4 },
{ key: '2x', value: 2 },
{ key: '4x', value: 4 },
];
// Internal to human-readable parameters
export const PARAMETERS: { [key: string]: string } = {
prompt: 'Prompt',
iterations: 'Iterations',
steps: 'Steps',
cfgScale: 'CFG Scale',
height: 'Height',
width: 'Width',
sampler: 'Sampler',
seed: 'Seed',
img2imgStrength: 'img2img Strength',
gfpganStrength: 'GFPGAN Strength',
upscalingLevel: 'Upscaling Level',
upscalingStrength: 'Upscaling Strength',
initialImagePath: 'Initial Image',
maskPath: 'Initial Image Mask',
shouldFitToWidthHeight: 'Fit Initial Image',
seamless: 'Seamless Tiling',
prompt: 'Prompt',
iterations: 'Iterations',
steps: 'Steps',
cfgScale: 'CFG Scale',
height: 'Height',
width: 'Width',
sampler: 'Sampler',
seed: 'Seed',
img2imgStrength: 'img2img Strength',
gfpganStrength: 'GFPGAN Strength',
upscalingLevel: 'Upscaling Level',
upscalingStrength: 'Upscaling Strength',
initialImagePath: 'Initial Image',
maskPath: 'Initial Image Mask',
shouldFitToWidthHeight: 'Fit Initial Image',
seamless: 'Seamless Tiling',
};
export const NUMPY_RAND_MIN = 0;

View File

@ -1,7 +0,0 @@
import { useDispatch, useSelector } from 'react-redux';
import type { TypedUseSelectorHook } from 'react-redux';
import type { RootState, AppDispatch } from './store';
// Use throughout your app instead of plain `useDispatch` and `useSelector`
export const useAppDispatch: () => AppDispatch = useDispatch;
export const useAppSelector: TypedUseSelectorHook<RootState> = useSelector;

170
frontend/src/app/invokeai.d.ts vendored Normal file
View File

@ -0,0 +1,170 @@
/**
* Types for images, the things they are made of, and the things
* they make up.
*
* Generated images are txt2img and img2img images. They may have
* had additional postprocessing done on them when they were first
* generated.
*
* Postprocessed images are images which were not generated here
* but only postprocessed by the app. They only get postprocessing
* metadata and have a different image type, e.g. 'esrgan' or
* 'gfpgan'.
*/
/**
* TODO:
* Once an image has been generated, if it is postprocessed again,
* additional postprocessing steps are added to its postprocessing
* array.
*
* TODO: Better documentation of types.
*/
export declare type PromptItem = {
prompt: string;
weight: number;
};
export declare type Prompt = Array<PromptItem>;
export declare type SeedWeightPair = {
seed: number;
weight: number;
};
export declare type SeedWeights = Array<SeedWeightPair>;
// All generated images contain these metadata.
export declare type CommonGeneratedImageMetadata = {
postprocessing: null | Array<ESRGANMetadata | GFPGANMetadata>;
sampler:
| 'ddim'
| 'k_dpm_2_a'
| 'k_dpm_2'
| 'k_euler_a'
| 'k_euler'
| 'k_heun'
| 'k_lms'
| 'plms';
prompt: Prompt;
seed: number;
variations: SeedWeights;
steps: number;
cfg_scale: number;
width: number;
height: number;
seamless: boolean;
extra: null | Record<string, never>; // Pending development of RFC #266
};
// txt2img and img2img images have some unique attributes.
export declare type Txt2ImgMetadata = GeneratedImageMetadata & {
type: 'txt2img';
};
export declare type Img2ImgMetadata = GeneratedImageMetadata & {
type: 'img2img';
orig_hash: string;
strength: number;
fit: boolean;
init_image_path: string;
mask_image_path?: string;
};
// Superset of generated image metadata types.
export declare type GeneratedImageMetadata = Txt2ImgMetadata | Img2ImgMetadata;
// All post processed images contain these metadata.
export declare type CommonPostProcessedImageMetadata = {
orig_path: string;
orig_hash: string;
};
// esrgan and gfpgan images have some unique attributes.
export declare type ESRGANMetadata = CommonPostProcessedImageMetadata & {
type: 'esrgan';
scale: 2 | 4;
strength: number;
};
export declare type GFPGANMetadata = CommonPostProcessedImageMetadata & {
type: 'gfpgan';
strength: number;
};
// Superset of all postprocessed image metadata types..
export declare type PostProcessedImageMetadata =
| ESRGANMetadata
| GFPGANMetadata;
// Metadata includes the system config and image metadata.
export declare type Metadata = SystemConfig & {
image: GeneratedImageMetadata | PostProcessedImageMetadata;
};
// An Image has a UUID, url (path?) and Metadata.
export declare type Image = {
uuid: string;
url: string;
metadata: Metadata;
};
// GalleryImages is an array of Image.
export declare type GalleryImages = {
images: Array<Image>;
};
/**
* Types related to the system status.
*/
// This represents the processing status of the backend.
export declare type SystemStatus = {
isProcessing: boolean;
currentStep: number;
totalSteps: number;
currentIteration: number;
totalIterations: number;
currentStatus: string;
currentStatusHasSteps: boolean;
};
export declare type SystemConfig = {
model: string;
model_id: string;
model_hash: string;
app_id: string;
app_version: string;
};
/**
* These types type data received from the server via socketio.
*/
export declare type SystemStatusResponse = SystemStatus;
export declare type SystemConfigResponse = SystemConfig;
export declare type ImageResultResponse = {
url: string;
metadata: Metadata;
};
export declare type ErrorResponse = {
message: string;
additionalData?: string;
};
export declare type GalleryImagesResponse = {
images: Array<{ url: string; metadata: Metadata }>;
};
export declare type ImageUrlAndUuidResponse = {
uuid: string;
url: string;
};
export declare type ImageUrlResponse = {
url: string;
};

View File

@ -1,182 +0,0 @@
import { SDState } from '../features/sd/sdSlice';
import randomInt from '../features/sd/util/randomInt';
import {
seedWeightsToString,
stringToSeedWeights,
} from '../features/sd/util/seedWeightPairs';
import { SystemState } from '../features/system/systemSlice';
import { NUMPY_RAND_MAX, NUMPY_RAND_MIN } from './constants';
/*
These functions translate frontend state into parameters
suitable for consumption by the backend, and vice-versa.
*/
export const frontendToBackendParameters = (
sdState: SDState,
systemState: SystemState
): { [key: string]: any } => {
const {
prompt,
iterations,
steps,
cfgScale,
height,
width,
sampler,
seed,
seamless,
shouldUseInitImage,
img2imgStrength,
initialImagePath,
maskPath,
shouldFitToWidthHeight,
shouldGenerateVariations,
variantAmount,
seedWeights,
shouldRunESRGAN,
upscalingLevel,
upscalingStrength,
shouldRunGFPGAN,
gfpganStrength,
shouldRandomizeSeed,
} = sdState;
const { shouldDisplayInProgress } = systemState;
const generationParameters: { [k: string]: any } = {
prompt,
iterations,
steps,
cfg_scale: cfgScale,
height,
width,
sampler_name: sampler,
seed,
seamless,
progress_images: shouldDisplayInProgress,
};
generationParameters.seed = shouldRandomizeSeed
? randomInt(NUMPY_RAND_MIN, NUMPY_RAND_MAX)
: seed;
if (shouldUseInitImage) {
generationParameters.init_img = initialImagePath;
generationParameters.strength = img2imgStrength;
generationParameters.fit = shouldFitToWidthHeight;
if (maskPath) {
generationParameters.init_mask = maskPath;
}
}
if (shouldGenerateVariations) {
generationParameters.variation_amount = variantAmount;
if (seedWeights) {
generationParameters.with_variations =
stringToSeedWeights(seedWeights);
}
} else {
generationParameters.variation_amount = 0;
}
let esrganParameters: false | { [k: string]: any } = false;
let gfpganParameters: false | { [k: string]: any } = false;
if (shouldRunESRGAN) {
esrganParameters = {
level: upscalingLevel,
strength: upscalingStrength,
};
}
if (shouldRunGFPGAN) {
gfpganParameters = {
strength: gfpganStrength,
};
}
return {
generationParameters,
esrganParameters,
gfpganParameters,
};
};
export const backendToFrontendParameters = (parameters: {
[key: string]: any;
}) => {
const {
prompt,
iterations,
steps,
cfg_scale,
height,
width,
sampler_name,
seed,
seamless,
progress_images,
variation_amount,
with_variations,
gfpgan_strength,
upscale,
init_img,
init_mask,
strength,
} = parameters;
const sd: { [key: string]: any } = {
shouldDisplayInProgress: progress_images,
// init
shouldGenerateVariations: false,
shouldRunESRGAN: false,
shouldRunGFPGAN: false,
initialImagePath: '',
maskPath: '',
};
if (variation_amount > 0) {
sd.shouldGenerateVariations = true;
sd.variantAmount = variation_amount;
if (with_variations) {
sd.seedWeights = seedWeightsToString(with_variations);
}
}
if (gfpgan_strength > 0) {
sd.shouldRunGFPGAN = true;
sd.gfpganStrength = gfpgan_strength;
}
if (upscale) {
sd.shouldRunESRGAN = true;
sd.upscalingLevel = upscale[0];
sd.upscalingStrength = upscale[1];
}
if (init_img) {
sd.shouldUseInitImage = true
sd.initialImagePath = init_img;
sd.strength = strength;
if (init_mask) {
sd.maskPath = init_mask;
}
}
// if we had a prompt, add all the metadata, but if we don't have a prompt,
// we must have only done ESRGAN or GFPGAN so do not add that metadata
if (prompt) {
sd.prompt = prompt;
sd.iterations = iterations;
sd.steps = steps;
sd.cfgScale = cfg_scale;
sd.height = height;
sd.width = width;
sd.sampler = sampler_name;
sd.seed = seed;
sd.seamless = seamless;
}
return sd;
};

View File

@ -1,393 +0,0 @@
import { createAction, Middleware } from '@reduxjs/toolkit';
import { io } from 'socket.io-client';
import {
addImage,
clearIntermediateImage,
removeImage,
SDImage,
SDMetadata,
setGalleryImages,
setIntermediateImage,
} from '../features/gallery/gallerySlice';
import {
addLogEntry,
setCurrentStep,
setIsConnected,
setIsProcessing,
} from '../features/system/systemSlice';
import { v4 as uuidv4 } from 'uuid';
import { setInitialImagePath, setMaskPath } from '../features/sd/sdSlice';
import {
backendToFrontendParameters,
frontendToBackendParameters,
} from './parameterTranslation';
export interface SocketIOResponse {
status: 'OK' | 'ERROR';
message?: string;
data?: any;
}
export const socketioMiddleware = () => {
const { hostname, port } = new URL(window.location.href);
const socketio = io(`http://${hostname}:9090`);
let areListenersSet = false;
const middleware: Middleware = (store) => (next) => (action) => {
const { dispatch, getState } = store;
if (!areListenersSet) {
// CONNECT
socketio.on('connect', () => {
try {
dispatch(setIsConnected(true));
} catch (e) {
console.error(e);
}
});
// DISCONNECT
socketio.on('disconnect', () => {
try {
dispatch(setIsConnected(false));
dispatch(setIsProcessing(false));
dispatch(addLogEntry(`Disconnected from server`));
} catch (e) {
console.error(e);
}
});
// PROCESSING RESULT
socketio.on(
'result',
(data: {
url: string;
type: 'generation' | 'esrgan' | 'gfpgan';
uuid?: string;
metadata: { [key: string]: any };
}) => {
try {
const newUuid = uuidv4();
const { type, url, uuid, metadata } = data;
switch (type) {
case 'generation': {
const translatedMetadata =
backendToFrontendParameters(metadata);
dispatch(
addImage({
uuid: newUuid,
url,
metadata: translatedMetadata,
})
);
dispatch(
addLogEntry(`Image generated: ${url}`)
);
break;
}
case 'esrgan': {
const originalImage =
getState().gallery.images.find(
(i: SDImage) => i.uuid === uuid
);
const newMetadata = {
...originalImage.metadata,
};
newMetadata.shouldRunESRGAN = true;
newMetadata.upscalingLevel =
metadata.upscale[0];
newMetadata.upscalingStrength =
metadata.upscale[1];
dispatch(
addImage({
uuid: newUuid,
url,
metadata: newMetadata,
})
);
dispatch(
addLogEntry(`ESRGAN upscaled: ${url}`)
);
break;
}
case 'gfpgan': {
const originalImage =
getState().gallery.images.find(
(i: SDImage) => i.uuid === uuid
);
const newMetadata = {
...originalImage.metadata,
};
newMetadata.shouldRunGFPGAN = true;
newMetadata.gfpganStrength =
metadata.gfpgan_strength;
dispatch(
addImage({
uuid: newUuid,
url,
metadata: newMetadata,
})
);
dispatch(
addLogEntry(`GFPGAN fixed faces: ${url}`)
);
break;
}
}
dispatch(setIsProcessing(false));
} catch (e) {
console.error(e);
}
}
);
// PROGRESS UPDATE
socketio.on('progress', (data: { step: number }) => {
try {
dispatch(setIsProcessing(true));
dispatch(setCurrentStep(data.step));
} catch (e) {
console.error(e);
}
});
// INTERMEDIATE IMAGE
socketio.on(
'intermediateResult',
(data: { url: string; metadata: SDMetadata }) => {
try {
const uuid = uuidv4();
const { url, metadata } = data;
dispatch(
setIntermediateImage({
uuid,
url,
metadata,
})
);
dispatch(
addLogEntry(`Intermediate image generated: ${url}`)
);
} catch (e) {
console.error(e);
}
}
);
// ERROR FROM BACKEND
socketio.on('error', (message) => {
try {
dispatch(addLogEntry(`Server error: ${message}`));
dispatch(setIsProcessing(false));
dispatch(clearIntermediateImage());
} catch (e) {
console.error(e);
}
});
areListenersSet = true;
}
// HANDLE ACTIONS
switch (action.type) {
// GENERATE IMAGE
case 'socketio/generateImage': {
dispatch(setIsProcessing(true));
dispatch(setCurrentStep(-1));
const {
generationParameters,
esrganParameters,
gfpganParameters,
} = frontendToBackendParameters(
getState().sd,
getState().system
);
socketio.emit(
'generateImage',
generationParameters,
esrganParameters,
gfpganParameters
);
dispatch(
addLogEntry(
`Image generation requested: ${JSON.stringify({
...generationParameters,
...esrganParameters,
...gfpganParameters,
})}`
)
);
break;
}
// RUN ESRGAN (UPSCALING)
case 'socketio/runESRGAN': {
const imageToProcess = action.payload;
dispatch(setIsProcessing(true));
dispatch(setCurrentStep(-1));
const { upscalingLevel, upscalingStrength } = getState().sd;
const esrganParameters = {
upscale: [upscalingLevel, upscalingStrength],
};
socketio.emit('runESRGAN', imageToProcess, esrganParameters);
dispatch(
addLogEntry(
`ESRGAN upscale requested: ${JSON.stringify({
file: imageToProcess.url,
...esrganParameters,
})}`
)
);
break;
}
// RUN GFPGAN (FIX FACES)
case 'socketio/runGFPGAN': {
const imageToProcess = action.payload;
dispatch(setIsProcessing(true));
dispatch(setCurrentStep(-1));
const { gfpganStrength } = getState().sd;
const gfpganParameters = {
gfpgan_strength: gfpganStrength,
};
socketio.emit('runGFPGAN', imageToProcess, gfpganParameters);
dispatch(
addLogEntry(
`GFPGAN fix faces requested: ${JSON.stringify({
file: imageToProcess.url,
...gfpganParameters,
})}`
)
);
break;
}
// DELETE IMAGE
case 'socketio/deleteImage': {
const imageToDelete = action.payload;
const { url } = imageToDelete;
socketio.emit(
'deleteImage',
url,
(response: SocketIOResponse) => {
if (response.status === 'OK') {
dispatch(removeImage(imageToDelete));
dispatch(addLogEntry(`Image deleted: ${url}`));
}
}
);
break;
}
// GET ALL IMAGES FOR GALLERY
case 'socketio/requestAllImages': {
socketio.emit(
'requestAllImages',
(response: SocketIOResponse) => {
dispatch(setGalleryImages(response.data));
dispatch(
addLogEntry(`Loaded ${response.data.length} images`)
);
}
);
break;
}
// CANCEL PROCESSING
case 'socketio/cancelProcessing': {
socketio.emit('cancel', (response: SocketIOResponse) => {
const { intermediateImage } = getState().gallery;
if (response.status === 'OK') {
dispatch(setIsProcessing(false));
if (intermediateImage) {
dispatch(addImage(intermediateImage));
dispatch(
addLogEntry(
`Intermediate image saved: ${intermediateImage.url}`
)
);
dispatch(clearIntermediateImage());
}
dispatch(addLogEntry(`Processing canceled`));
}
});
break;
}
// UPLOAD INITIAL IMAGE
case 'socketio/uploadInitialImage': {
const file = action.payload;
socketio.emit(
'uploadInitialImage',
file,
file.name,
(response: SocketIOResponse) => {
if (response.status === 'OK') {
dispatch(setInitialImagePath(response.data));
dispatch(
addLogEntry(
`Initial image uploaded: ${response.data}`
)
);
}
}
);
break;
}
// UPLOAD MASK IMAGE
case 'socketio/uploadMaskImage': {
const file = action.payload;
socketio.emit(
'uploadMaskImage',
file,
file.name,
(response: SocketIOResponse) => {
if (response.status === 'OK') {
dispatch(setMaskPath(response.data));
dispatch(
addLogEntry(
`Mask image uploaded: ${response.data}`
)
);
}
}
);
break;
}
}
next(action);
};
return middleware;
};
// Actions to be used by app
export const generateImage = createAction<undefined>('socketio/generateImage');
export const runESRGAN = createAction<SDImage>('socketio/runESRGAN');
export const runGFPGAN = createAction<SDImage>('socketio/runGFPGAN');
export const deleteImage = createAction<SDImage>('socketio/deleteImage');
export const requestAllImages = createAction<undefined>(
'socketio/requestAllImages'
);
export const cancelProcessing = createAction<undefined>(
'socketio/cancelProcessing'
);
export const uploadInitialImage = createAction<File>(
'socketio/uploadInitialImage'
);
export const uploadMaskImage = createAction<File>('socketio/uploadMaskImage');

View File

@ -0,0 +1,26 @@
import { createAction } from '@reduxjs/toolkit';
import * as InvokeAI from '../invokeai';
/**
* We can't use redux-toolkit's createSlice() to make these actions,
* because they have no associated reducer. They only exist to dispatch
* requests to the server via socketio. These actions will be handled
* by the middleware.
*/
export const generateImage = createAction<undefined>('socketio/generateImage');
export const runESRGAN = createAction<InvokeAI.Image>('socketio/runESRGAN');
export const runGFPGAN = createAction<InvokeAI.Image>('socketio/runGFPGAN');
export const deleteImage = createAction<InvokeAI.Image>('socketio/deleteImage');
export const requestAllImages = createAction<undefined>(
'socketio/requestAllImages'
);
export const cancelProcessing = createAction<undefined>(
'socketio/cancelProcessing'
);
export const uploadInitialImage = createAction<File>(
'socketio/uploadInitialImage'
);
export const uploadMaskImage = createAction<File>('socketio/uploadMaskImage');
export const requestSystemConfig = createAction<undefined>('socketio/requestSystemConfig');

View File

@ -0,0 +1,104 @@
import { AnyAction, Dispatch, MiddlewareAPI } from '@reduxjs/toolkit';
import dateFormat from 'dateformat';
import { Socket } from 'socket.io-client';
import { frontendToBackendParameters } from '../../common/util/parameterTranslation';
import {
addLogEntry,
setIsProcessing,
} from '../../features/system/systemSlice';
import * as InvokeAI from '../invokeai';
/**
* Returns an object containing all functions which use `socketio.emit()`.
* i.e. those which make server requests.
*/
const makeSocketIOEmitters = (
store: MiddlewareAPI<Dispatch<AnyAction>, any>,
socketio: Socket
) => {
// We need to dispatch actions to redux and get pieces of state from the store.
const { dispatch, getState } = store;
return {
emitGenerateImage: () => {
dispatch(setIsProcessing(true));
const { generationParameters, esrganParameters, gfpganParameters } =
frontendToBackendParameters(getState().options, getState().system);
socketio.emit(
'generateImage',
generationParameters,
esrganParameters,
gfpganParameters
);
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `Image generation requested: ${JSON.stringify({
...generationParameters,
...esrganParameters,
...gfpganParameters,
})}`,
})
);
},
emitRunESRGAN: (imageToProcess: InvokeAI.Image) => {
dispatch(setIsProcessing(true));
const { upscalingLevel, upscalingStrength } = getState().options;
const esrganParameters = {
upscale: [upscalingLevel, upscalingStrength],
};
socketio.emit('runESRGAN', imageToProcess, esrganParameters);
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `ESRGAN upscale requested: ${JSON.stringify({
file: imageToProcess.url,
...esrganParameters,
})}`,
})
);
},
emitRunGFPGAN: (imageToProcess: InvokeAI.Image) => {
dispatch(setIsProcessing(true));
const { gfpganStrength } = getState().options;
const gfpganParameters = {
gfpgan_strength: gfpganStrength,
};
socketio.emit('runGFPGAN', imageToProcess, gfpganParameters);
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `GFPGAN fix faces requested: ${JSON.stringify({
file: imageToProcess.url,
...gfpganParameters,
})}`,
})
);
},
emitDeleteImage: (imageToDelete: InvokeAI.Image) => {
const { url, uuid } = imageToDelete;
socketio.emit('deleteImage', url, uuid);
},
emitRequestAllImages: () => {
socketio.emit('requestAllImages');
},
emitCancelProcessing: () => {
socketio.emit('cancel');
},
emitUploadInitialImage: (file: File) => {
socketio.emit('uploadInitialImage', file, file.name);
},
emitUploadMaskImage: (file: File) => {
socketio.emit('uploadMaskImage', file, file.name);
},
emitRequestSystemConfig: () => {
socketio.emit('requestSystemConfig')
}
};
};
export default makeSocketIOEmitters;

View File

@ -0,0 +1,300 @@
import { AnyAction, MiddlewareAPI, Dispatch } from '@reduxjs/toolkit';
import { v4 as uuidv4 } from 'uuid';
import dateFormat from 'dateformat';
import * as InvokeAI from '../invokeai';
import {
addLogEntry,
setIsConnected,
setIsProcessing,
setSystemStatus,
setCurrentStatus,
setSystemConfig,
} from '../../features/system/systemSlice';
import {
addImage,
clearIntermediateImage,
removeImage,
setGalleryImages,
setIntermediateImage,
} from '../../features/gallery/gallerySlice';
import {
setInitialImagePath,
setMaskPath,
} from '../../features/options/optionsSlice';
/**
* Returns an object containing listener callbacks for socketio events.
* TODO: This file is large, but simple. Should it be split up further?
*/
const makeSocketIOListeners = (
store: MiddlewareAPI<Dispatch<AnyAction>, any>
) => {
const { dispatch, getState } = store;
return {
/**
* Callback to run when we receive a 'connect' event.
*/
onConnect: () => {
try {
dispatch(setIsConnected(true));
dispatch(setCurrentStatus('Connected'));
} catch (e) {
console.error(e);
}
},
/**
* Callback to run when we receive a 'disconnect' event.
*/
onDisconnect: () => {
try {
dispatch(setIsConnected(false));
dispatch(setIsProcessing(false));
dispatch(setCurrentStatus('Disconnected'));
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `Disconnected from server`,
level: 'warning',
})
);
} catch (e) {
console.error(e);
}
},
/**
* Callback to run when we receive a 'generationResult' event.
*/
onGenerationResult: (data: InvokeAI.ImageResultResponse) => {
try {
const { url, metadata } = data;
const newUuid = uuidv4();
dispatch(
addImage({
uuid: newUuid,
url,
metadata: metadata,
})
);
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `Image generated: ${url}`,
})
);
dispatch(setIsProcessing(false));
} catch (e) {
console.error(e);
}
},
/**
* Callback to run when we receive a 'intermediateResult' event.
*/
onIntermediateResult: (data: InvokeAI.ImageResultResponse) => {
try {
const uuid = uuidv4();
const { url, metadata } = data;
dispatch(
setIntermediateImage({
uuid,
url,
metadata,
})
);
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `Intermediate image generated: ${url}`,
})
);
dispatch(setIsProcessing(false));
} catch (e) {
console.error(e);
}
},
/**
* Callback to run when we receive an 'esrganResult' event.
*/
onESRGANResult: (data: InvokeAI.ImageResultResponse) => {
try {
const { url, metadata } = data;
dispatch(
addImage({
uuid: uuidv4(),
url,
metadata,
})
);
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `Upscaled: ${url}`,
})
);
dispatch(setIsProcessing(false));
} catch (e) {
console.error(e);
}
},
/**
* Callback to run when we receive a 'gfpganResult' event.
*/
onGFPGANResult: (data: InvokeAI.ImageResultResponse) => {
try {
const { url, metadata } = data;
dispatch(
addImage({
uuid: uuidv4(),
url,
metadata,
})
);
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `Fixed faces: ${url}`,
})
);
} catch (e) {
console.error(e);
}
},
/**
* Callback to run when we receive a 'progressUpdate' event.
* TODO: Add additional progress phases
*/
onProgressUpdate: (data: InvokeAI.SystemStatus) => {
try {
dispatch(setIsProcessing(true));
dispatch(setSystemStatus(data));
} catch (e) {
console.error(e);
}
},
/**
* Callback to run when we receive a 'progressUpdate' event.
*/
onError: (data: InvokeAI.ErrorResponse) => {
const { message, additionalData } = data;
if (additionalData) {
// TODO: handle more data than short message
}
try {
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `Server error: ${message}`,
level: 'error',
})
);
dispatch(setIsProcessing(false));
dispatch(clearIntermediateImage());
} catch (e) {
console.error(e);
}
},
/**
* Callback to run when we receive a 'galleryImages' event.
*/
onGalleryImages: (data: InvokeAI.GalleryImagesResponse) => {
const { images } = data;
const preparedImages = images.map((image): InvokeAI.Image => {
const { url, metadata } = image;
return {
uuid: uuidv4(),
url,
metadata,
};
});
dispatch(setGalleryImages(preparedImages));
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `Loaded ${images.length} images`,
})
);
},
/**
* Callback to run when we receive a 'processingCanceled' event.
*/
onProcessingCanceled: () => {
dispatch(setIsProcessing(false));
const { intermediateImage } = getState().gallery;
if (intermediateImage) {
dispatch(addImage(intermediateImage));
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `Intermediate image saved: ${intermediateImage.url}`,
})
);
dispatch(clearIntermediateImage());
}
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `Processing canceled`,
level: 'warning',
})
);
},
/**
* Callback to run when we receive a 'imageDeleted' event.
*/
onImageDeleted: (data: InvokeAI.ImageUrlAndUuidResponse) => {
const { url, uuid } = data;
dispatch(removeImage(uuid));
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `Image deleted: ${url}`,
})
);
},
/**
* Callback to run when we receive a 'initialImageUploaded' event.
*/
onInitialImageUploaded: (data: InvokeAI.ImageUrlResponse) => {
const { url } = data;
dispatch(setInitialImagePath(url));
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `Initial image uploaded: ${url}`,
})
);
},
/**
* Callback to run when we receive a 'maskImageUploaded' event.
*/
onMaskImageUploaded: (data: InvokeAI.ImageUrlResponse) => {
const { url } = data;
dispatch(setMaskPath(url));
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `Mask image uploaded: ${url}`,
})
);
},
onSystemConfig: (data: InvokeAI.SystemConfig) => {
dispatch(setSystemConfig(data));
},
};
};
export default makeSocketIOListeners;

View File

@ -0,0 +1,173 @@
import { Middleware } from '@reduxjs/toolkit';
import { io } from 'socket.io-client';
import makeSocketIOListeners from './listeners';
import makeSocketIOEmitters from './emitters';
import * as InvokeAI from '../invokeai';
/**
* Creates a socketio middleware to handle communication with server.
*
* Special `socketio/actionName` actions are created in actions.ts and
* exported for use by the application, which treats them like any old
* action, using `dispatch` to dispatch them.
*
* These actions are intercepted here, where `socketio.emit()` calls are
* made on their behalf - see `emitters.ts`. The emitter functions
* are the outbound communication to the server.
*
* Listeners are also established here - see `listeners.ts`. The listener
* functions receive communication from the server and usually dispatch
* some new action to handle whatever data was sent from the server.
*/
export const socketioMiddleware = () => {
const { hostname, port } = new URL(window.location.href);
const socketio = io(`http://${hostname}:9090`);
let areListenersSet = false;
const middleware: Middleware = (store) => (next) => (action) => {
const {
onConnect,
onDisconnect,
onError,
onESRGANResult,
onGFPGANResult,
onGenerationResult,
onIntermediateResult,
onProgressUpdate,
onGalleryImages,
onProcessingCanceled,
onImageDeleted,
onInitialImageUploaded,
onMaskImageUploaded,
onSystemConfig,
} = makeSocketIOListeners(store);
const {
emitGenerateImage,
emitRunESRGAN,
emitRunGFPGAN,
emitDeleteImage,
emitRequestAllImages,
emitCancelProcessing,
emitUploadInitialImage,
emitUploadMaskImage,
emitRequestSystemConfig,
} = makeSocketIOEmitters(store, socketio);
/**
* If this is the first time the middleware has been called (e.g. during store setup),
* initialize all our socket.io listeners.
*/
if (!areListenersSet) {
socketio.on('connect', () => onConnect());
socketio.on('disconnect', () => onDisconnect());
socketio.on('error', (data: InvokeAI.ErrorResponse) => onError(data));
socketio.on('generationResult', (data: InvokeAI.ImageResultResponse) =>
onGenerationResult(data)
);
socketio.on('esrganResult', (data: InvokeAI.ImageResultResponse) =>
onESRGANResult(data)
);
socketio.on('gfpganResult', (data: InvokeAI.ImageResultResponse) =>
onGFPGANResult(data)
);
socketio.on('intermediateResult', (data: InvokeAI.ImageResultResponse) =>
onIntermediateResult(data)
);
socketio.on('progressUpdate', (data: InvokeAI.SystemStatus) =>
onProgressUpdate(data)
);
socketio.on('galleryImages', (data: InvokeAI.GalleryImagesResponse) =>
onGalleryImages(data)
);
socketio.on('processingCanceled', () => {
onProcessingCanceled();
});
socketio.on('imageDeleted', (data: InvokeAI.ImageUrlAndUuidResponse) => {
onImageDeleted(data);
});
socketio.on('initialImageUploaded', (data: InvokeAI.ImageUrlResponse) => {
onInitialImageUploaded(data);
});
socketio.on('maskImageUploaded', (data: InvokeAI.ImageUrlResponse) => {
onMaskImageUploaded(data);
});
socketio.on('systemConfig', (data: InvokeAI.SystemConfig) => {
onSystemConfig(data);
});
areListenersSet = true;
}
/**
* Handle redux actions caught by middleware.
*/
switch (action.type) {
case 'socketio/generateImage': {
emitGenerateImage();
break;
}
case 'socketio/runESRGAN': {
emitRunESRGAN(action.payload);
break;
}
case 'socketio/runGFPGAN': {
emitRunGFPGAN(action.payload);
break;
}
case 'socketio/deleteImage': {
emitDeleteImage(action.payload);
break;
}
case 'socketio/requestAllImages': {
emitRequestAllImages();
break;
}
case 'socketio/cancelProcessing': {
emitCancelProcessing();
break;
}
case 'socketio/uploadInitialImage': {
emitUploadInitialImage(action.payload);
break;
}
case 'socketio/uploadMaskImage': {
emitUploadMaskImage(action.payload);
break;
}
case 'socketio/requestSystemConfig': {
emitRequestSystemConfig();
break;
}
}
next(action);
};
return middleware;
};

View File

@ -1,53 +1,78 @@
import { combineReducers, configureStore } from '@reduxjs/toolkit';
import { useDispatch, useSelector } from 'react-redux';
import type { TypedUseSelectorHook } from 'react-redux';
import { persistReducer } from 'redux-persist';
import storage from 'redux-persist/lib/storage'; // defaults to localStorage for web
import sdReducer from '../features/sd/sdSlice';
import optionsReducer from '../features/options/optionsSlice';
import galleryReducer from '../features/gallery/gallerySlice';
import systemReducer from '../features/system/systemSlice';
import { socketioMiddleware } from './socketio';
import { socketioMiddleware } from './socketio/middleware';
const reducers = combineReducers({
sd: sdReducer,
gallery: galleryReducer,
system: systemReducer,
});
/**
* redux-persist provides an easy and reliable way to persist state across reloads.
*
* While we definitely want generation parameters to be persisted, there are a number
* of things we do *not* want to be persisted across reloads:
* - Gallery/selected image (user may add/delete images from disk between page loads)
* - Connection/processing status
* - Availability of external libraries like ESRGAN/GFPGAN
*
* These can be blacklisted in redux-persist.
*
* The necesssary nested persistors with blacklists are configured below.
*
* TODO: Do we blacklist initialImagePath? If the image is deleted from disk we get an
* ugly 404. But if we blacklist it, then this is a valuable parameter that is lost
* on reload. Need to figure out a good way to handle this.
*/
const persistConfig = {
const rootPersistConfig = {
key: 'root',
storage,
blacklist: ['gallery', 'system'],
};
const persistedReducer = persistReducer(persistConfig, reducers);
const systemPersistConfig = {
key: 'system',
storage,
blacklist: [
'isConnected',
'isProcessing',
'currentStep',
'socketId',
'isESRGANAvailable',
'isGFPGANAvailable',
'currentStep',
'totalSteps',
'currentIteration',
'totalIterations',
'currentStatus',
],
};
/*
The frontend needs to be distributed as a production build, so
we cannot reasonably ask users to edit the JS and specify the
host and port on which the socket.io server will run.
The solution is to allow server script to be run with arguments
(or just edited) providing the host and port. Then, the server
serves a route `/socketio_config` which responds with the host
and port.
When the frontend loads, it synchronously requests that route
and thus gets the host and port. This requires a suspicious
fetch somewhere, and the store setup seems like as good a place
as any to make this fetch request.
*/
const reducers = combineReducers({
options: optionsReducer,
gallery: galleryReducer,
system: persistReducer(systemPersistConfig, systemReducer),
});
const persistedReducer = persistReducer(rootPersistConfig, reducers);
// Continue with store setup
export const store = configureStore({
reducer: persistedReducer,
middleware: (getDefaultMiddleware) =>
getDefaultMiddleware({
// redux-persist sometimes needs to have a function in redux, need to disable this check
// redux-persist sometimes needs to temporarily put a function in redux state, need to disable this check
serializableCheck: false,
}).concat(socketioMiddleware()),
});
// Infer the `RootState` and `AppDispatch` types from the store itself
export type RootState = ReturnType<typeof store.getState>;
// Inferred type: {posts: PostsState, comments: CommentsState, users: UsersState}
export type AppDispatch = typeof store.dispatch;
// Use throughout your app instead of plain `useDispatch` and `useSelector`
export const useAppDispatch: () => AppDispatch = useDispatch;
export const useAppSelector: TypedUseSelectorHook<RootState> = useSelector;

View File

@ -33,5 +33,20 @@ export const theme = extendTheme({
fontWeight: 'light',
},
},
Button: {
variants: {
imageHoverIconButton: (props: StyleFunctionProps) => ({
bg: props.colorMode === 'dark' ? 'blackAlpha.700' : 'whiteAlpha.800',
color:
props.colorMode === 'dark' ? 'whiteAlpha.700' : 'blackAlpha.700',
_hover: {
bg:
props.colorMode === 'dark' ? 'blackAlpha.800' : 'whiteAlpha.800',
color:
props.colorMode === 'dark' ? 'whiteAlpha.900' : 'blackAlpha.900',
},
}),
},
},
},
});

View File

@ -0,0 +1,21 @@
import { Button, ButtonProps } from '@chakra-ui/react';
interface Props extends ButtonProps {
label: string;
}
/**
* Reusable customized button component. Originally was more customized - now probably unecessary.
*
* TODO: Get rid of this.
*/
const SDButton = (props: Props) => {
const { label, size = 'sm', ...rest } = props;
return (
<Button size={size} {...rest}>
{label}
</Button>
);
};
export default SDButton;

View File

@ -16,6 +16,9 @@ interface Props extends NumberInputProps {
width?: string | number;
}
/**
* Customized Chakra FormControl + NumberInput multi-part component.
*/
const SDNumberInput = (props: Props) => {
const {
label,
@ -31,7 +34,7 @@ const SDNumberInput = (props: Props) => {
<Flex gap={2} justifyContent={'space-between'} alignItems={'center'}>
{label && (
<FormLabel marginBottom={1}>
<Text fontSize={fontSize} whiteSpace='nowrap'>
<Text fontSize={fontSize} whiteSpace="nowrap">
{label}
</Text>
</FormLabel>
@ -42,7 +45,7 @@ const SDNumberInput = (props: Props) => {
keepWithinRange={false}
clampValueOnBlur={true}
>
<NumberInputField fontSize={'md'}/>
<NumberInputField fontSize={'md'} />
<NumberInputStepper>
<NumberIncrementStepper />
<NumberDecrementStepper />

View File

@ -0,0 +1,56 @@
import {
Flex,
FormControl,
FormLabel,
Select,
SelectProps,
Text,
} from '@chakra-ui/react';
interface Props extends SelectProps {
label: string;
validValues:
| Array<number | string>
| Array<{ key: string; value: string | number }>;
}
/**
* Customized Chakra FormControl + Select multi-part component.
*/
const SDSelect = (props: Props) => {
const {
label,
isDisabled,
validValues,
size = 'sm',
fontSize = 'md',
marginBottom = 1,
whiteSpace = 'nowrap',
...rest
} = props;
return (
<FormControl isDisabled={isDisabled}>
<Flex justifyContent={'space-between'} alignItems={'center'}>
<FormLabel marginBottom={marginBottom}>
<Text fontSize={fontSize} whiteSpace={whiteSpace}>
{label}
</Text>
</FormLabel>
<Select fontSize={fontSize} size={size} {...rest}>
{validValues.map((opt) => {
return typeof opt === 'string' || typeof opt === 'number' ? (
<option key={opt} value={opt}>
{opt}
</option>
) : (
<option key={opt.value} value={opt.value}>
{opt.key}
</option>
);
})}
</Select>
</Flex>
</FormControl>
);
};
export default SDSelect;

View File

@ -11,6 +11,9 @@ interface Props extends SwitchProps {
width?: string | number;
}
/**
* Customized Chakra FormControl + Switch multi-part component.
*/
const SDSwitch = (props: Props) => {
const {
label,
@ -28,7 +31,7 @@ const SDSwitch = (props: Props) => {
fontSize={fontSize}
marginBottom={1}
flexGrow={2}
whiteSpace='nowrap'
whiteSpace="nowrap"
>
{label}
</FormLabel>

View File

@ -0,0 +1,104 @@
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash';
import { useMemo } from 'react';
import { useAppSelector } from '../../app/store';
import { RootState } from '../../app/store';
import { OptionsState } from '../../features/options/optionsSlice';
import { SystemState } from '../../features/system/systemSlice';
import { validateSeedWeights } from '../util/seedWeightPairs';
const optionsSelector = createSelector(
(state: RootState) => state.options,
(options: OptionsState) => {
return {
prompt: options.prompt,
shouldGenerateVariations: options.shouldGenerateVariations,
seedWeights: options.seedWeights,
maskPath: options.maskPath,
initialImagePath: options.initialImagePath,
seed: options.seed,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
const systemSelector = createSelector(
(state: RootState) => state.system,
(system: SystemState) => {
return {
isProcessing: system.isProcessing,
isConnected: system.isConnected,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
/**
* Checks relevant pieces of state to confirm generation will not deterministically fail.
* This is used to prevent the 'Generate' button from being clicked.
*/
const useCheckParameters = (): boolean => {
const {
prompt,
shouldGenerateVariations,
seedWeights,
maskPath,
initialImagePath,
seed,
} = useAppSelector(optionsSelector);
const { isProcessing, isConnected } = useAppSelector(systemSelector);
return useMemo(() => {
// Cannot generate without a prompt
if (!prompt) {
return false;
}
// Cannot generate with a mask without img2img
if (maskPath && !initialImagePath) {
return false;
}
// TODO: job queue
// Cannot generate if already processing an image
if (isProcessing) {
return false;
}
// Cannot generate if not connected
if (!isConnected) {
return false;
}
// Cannot generate variations without valid seed weights
if (
shouldGenerateVariations &&
(!(validateSeedWeights(seedWeights) || seedWeights === '') || seed === -1)
) {
return false;
}
// All good
return true;
}, [
prompt,
maskPath,
initialImagePath,
isProcessing,
isConnected,
shouldGenerateVariations,
seedWeights,
seed,
]);
};
export default useCheckParameters;

View File

@ -0,0 +1,182 @@
/*
These functions translate frontend state into parameters
suitable for consumption by the backend, and vice-versa.
*/
import { NUMPY_RAND_MAX, NUMPY_RAND_MIN } from '../../app/constants';
import { OptionsState } from '../../features/options/optionsSlice';
import { SystemState } from '../../features/system/systemSlice';
import {
seedWeightsToString,
stringToSeedWeightsArray,
} from './seedWeightPairs';
import randomInt from './randomInt';
export const frontendToBackendParameters = (
optionsState: OptionsState,
systemState: SystemState
): { [key: string]: any } => {
const {
prompt,
iterations,
steps,
cfgScale,
height,
width,
sampler,
seed,
seamless,
shouldUseInitImage,
img2imgStrength,
initialImagePath,
maskPath,
shouldFitToWidthHeight,
shouldGenerateVariations,
variationAmount,
seedWeights,
shouldRunESRGAN,
upscalingLevel,
upscalingStrength,
shouldRunGFPGAN,
gfpganStrength,
shouldRandomizeSeed,
} = optionsState;
const { shouldDisplayInProgress } = systemState;
const generationParameters: { [k: string]: any } = {
prompt,
iterations,
steps,
cfg_scale: cfgScale,
height,
width,
sampler_name: sampler,
seed,
seamless,
progress_images: shouldDisplayInProgress,
};
generationParameters.seed = shouldRandomizeSeed
? randomInt(NUMPY_RAND_MIN, NUMPY_RAND_MAX)
: seed;
if (shouldUseInitImage) {
generationParameters.init_img = initialImagePath;
generationParameters.strength = img2imgStrength;
generationParameters.fit = shouldFitToWidthHeight;
if (maskPath) {
generationParameters.init_mask = maskPath;
}
}
if (shouldGenerateVariations) {
generationParameters.variation_amount = variationAmount;
if (seedWeights) {
generationParameters.with_variations =
stringToSeedWeightsArray(seedWeights);
}
} else {
generationParameters.variation_amount = 0;
}
let esrganParameters: false | { [k: string]: any } = false;
let gfpganParameters: false | { [k: string]: any } = false;
if (shouldRunESRGAN) {
esrganParameters = {
level: upscalingLevel,
strength: upscalingStrength,
};
}
if (shouldRunGFPGAN) {
gfpganParameters = {
strength: gfpganStrength,
};
}
return {
generationParameters,
esrganParameters,
gfpganParameters,
};
};
export const backendToFrontendParameters = (parameters: {
[key: string]: any;
}) => {
const {
prompt,
iterations,
steps,
cfg_scale,
height,
width,
sampler_name,
seed,
seamless,
progress_images,
variation_amount,
with_variations,
gfpgan_strength,
upscale,
init_img,
init_mask,
strength,
} = parameters;
const options: { [key: string]: any } = {
shouldDisplayInProgress: progress_images,
// init
shouldGenerateVariations: false,
shouldRunESRGAN: false,
shouldRunGFPGAN: false,
initialImagePath: '',
maskPath: '',
};
if (variation_amount > 0) {
options.shouldGenerateVariations = true;
options.variationAmount = variation_amount;
if (with_variations) {
options.seedWeights = seedWeightsToString(with_variations);
}
}
if (gfpgan_strength > 0) {
options.shouldRunGFPGAN = true;
options.gfpganStrength = gfpgan_strength;
}
if (upscale) {
options.shouldRunESRGAN = true;
options.upscalingLevel = upscale[0];
options.upscalingStrength = upscale[1];
}
if (init_img) {
options.shouldUseInitImage = true;
options.initialImagePath = init_img;
options.strength = strength;
if (init_mask) {
options.maskPath = init_mask;
}
}
// if we had a prompt, add all the metadata, but if we don't have a prompt,
// we must have only done ESRGAN or GFPGAN so do not add that metadata
if (prompt) {
options.prompt = prompt;
options.iterations = iterations;
options.steps = steps;
options.cfgScale = cfg_scale;
options.height = height;
options.width = width;
options.sampler = sampler_name;
options.seed = seed;
options.seamless = seamless;
}
return options;
};

View File

@ -0,0 +1,16 @@
import * as InvokeAI from '../../app/invokeai';
const promptToString = (prompt: InvokeAI.Prompt): string => {
if (prompt.length === 1) {
return prompt[0].prompt;
}
return prompt
.map(
(promptItem: InvokeAI.PromptItem): string =>
`${promptItem.prompt}:${promptItem.weight}`
)
.join(' ');
};
export default promptToString;

View File

@ -0,0 +1,68 @@
import * as InvokeAI from '../../app/invokeai';
export const stringToSeedWeights = (
string: string
): InvokeAI.SeedWeights | boolean => {
const stringPairs = string.split(',');
const arrPairs = stringPairs.map((p) => p.split(':'));
const pairs = arrPairs.map((p: Array<string>): InvokeAI.SeedWeightPair => {
return { seed: parseInt(p[0]), weight: parseFloat(p[1]) };
});
if (!validateSeedWeights(pairs)) {
return false;
}
return pairs;
};
export const validateSeedWeights = (
seedWeights: InvokeAI.SeedWeights | string
): boolean => {
return typeof seedWeights === 'string'
? Boolean(stringToSeedWeights(seedWeights))
: Boolean(
seedWeights.length &&
!seedWeights.some((pair: InvokeAI.SeedWeightPair) => {
const { seed, weight } = pair;
const isSeedValid = !isNaN(parseInt(seed.toString(), 10));
const isWeightValid =
!isNaN(parseInt(weight.toString(), 10)) &&
weight >= 0 &&
weight <= 1;
return !(isSeedValid && isWeightValid);
})
);
};
export const seedWeightsToString = (
seedWeights: InvokeAI.SeedWeights
): string => {
return seedWeights.reduce((acc, pair, i, arr) => {
const { seed, weight } = pair;
acc += `${seed}:${weight}`;
if (i !== arr.length - 1) {
acc += ',';
}
return acc;
}, '');
};
export const seedWeightsToArray = (
seedWeights: InvokeAI.SeedWeights
): Array<Array<number>> => {
return seedWeights.map((pair: InvokeAI.SeedWeightPair) => [
pair.seed,
pair.weight,
]);
};
export const stringToSeedWeightsArray = (
string: string
): Array<Array<number>> => {
const stringPairs = string.split(',');
const arrPairs = stringPairs.map((p) => p.split(':'));
return arrPairs.map(
(p: Array<string>): Array<number> => [parseInt(p[0]), parseFloat(p[1])]
);
};

View File

@ -1,16 +0,0 @@
import { Button, ButtonProps } from '@chakra-ui/react';
interface Props extends ButtonProps {
label: string;
}
const SDButton = (props: Props) => {
const { label, size = 'sm', ...rest } = props;
return (
<Button size={size} {...rest}>
{label}
</Button>
);
};
export default SDButton;

View File

@ -1,57 +0,0 @@
import {
Flex,
FormControl,
FormLabel,
Select,
SelectProps,
Text,
} from '@chakra-ui/react';
interface Props extends SelectProps {
label: string;
validValues:
| Array<number | string>
| Array<{ key: string; value: string | number }>;
}
const SDSelect = (props: Props) => {
const {
label,
isDisabled,
validValues,
size = 'sm',
fontSize = 'md',
marginBottom = 1,
whiteSpace = 'nowrap',
...rest
} = props;
return (
<FormControl isDisabled={isDisabled}>
<Flex justifyContent={'space-between'} alignItems={'center'}>
<FormLabel
marginBottom={marginBottom}
>
<Text fontSize={fontSize} whiteSpace={whiteSpace}>
{label}
</Text>
</FormLabel>
<Select fontSize={fontSize} size={size} {...rest}>
{validValues.map((opt) => {
return typeof opt === 'string' ||
typeof opt === 'number' ? (
<option key={opt} value={opt}>
{opt}
</option>
) : (
<option key={opt.value} value={opt.value}>
{opt.key}
</option>
);
})}
</Select>
</Flex>
</FormControl>
);
};
export default SDSelect;

View File

@ -1,161 +0,0 @@
import { Center, Flex, Image, useColorModeValue } from '@chakra-ui/react';
import { useAppDispatch, useAppSelector } from '../../app/hooks';
import { RootState } from '../../app/store';
import { setAllParameters, setInitialImagePath, setSeed } from '../sd/sdSlice';
import { useState } from 'react';
import ImageMetadataViewer from './ImageMetadataViewer';
import DeleteImageModalButton from './DeleteImageModalButton';
import SDButton from '../../components/SDButton';
import { runESRGAN, runGFPGAN } from '../../app/socketio';
import { createSelector } from '@reduxjs/toolkit';
import { SystemState } from '../system/systemSlice';
import { isEqual } from 'lodash';
const height = 'calc(100vh - 238px)';
const systemSelector = createSelector(
(state: RootState) => state.system,
(system: SystemState) => {
return {
isProcessing: system.isProcessing,
isConnected: system.isConnected,
isGFPGANAvailable: system.isGFPGANAvailable,
isESRGANAvailable: system.isESRGANAvailable,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
const CurrentImage = () => {
const { currentImage, intermediateImage } = useAppSelector(
(state: RootState) => state.gallery
);
const { isProcessing, isConnected, isGFPGANAvailable, isESRGANAvailable } =
useAppSelector(systemSelector);
const dispatch = useAppDispatch();
const bgColor = useColorModeValue(
'rgba(255, 255, 255, 0.85)',
'rgba(0, 0, 0, 0.8)'
);
const [shouldShowImageDetails, setShouldShowImageDetails] =
useState<boolean>(false);
const imageToDisplay = intermediateImage || currentImage;
return (
<Flex direction={'column'} rounded={'md'} borderWidth={1} p={2} gap={2}>
{imageToDisplay && (
<Flex gap={2}>
<SDButton
label='Use as initial image'
colorScheme={'gray'}
flexGrow={1}
variant={'outline'}
onClick={() =>
dispatch(setInitialImagePath(imageToDisplay.url))
}
/>
<SDButton
label='Use all'
colorScheme={'gray'}
flexGrow={1}
variant={'outline'}
onClick={() =>
dispatch(setAllParameters(imageToDisplay.metadata))
}
/>
<SDButton
label='Use seed'
colorScheme={'gray'}
flexGrow={1}
variant={'outline'}
isDisabled={!imageToDisplay.metadata.seed}
onClick={() =>
dispatch(setSeed(imageToDisplay.metadata.seed!))
}
/>
<SDButton
label='Upscale'
colorScheme={'gray'}
flexGrow={1}
variant={'outline'}
isDisabled={
!isESRGANAvailable ||
Boolean(intermediateImage) ||
!(isConnected && !isProcessing)
}
onClick={() => dispatch(runESRGAN(imageToDisplay))}
/>
<SDButton
label='Fix faces'
colorScheme={'gray'}
flexGrow={1}
variant={'outline'}
isDisabled={
!isGFPGANAvailable ||
Boolean(intermediateImage) ||
!(isConnected && !isProcessing)
}
onClick={() => dispatch(runGFPGAN(imageToDisplay))}
/>
<SDButton
label='Details'
colorScheme={'gray'}
variant={shouldShowImageDetails ? 'solid' : 'outline'}
borderWidth={1}
flexGrow={1}
onClick={() =>
setShouldShowImageDetails(!shouldShowImageDetails)
}
/>
<DeleteImageModalButton image={imageToDisplay}>
<SDButton
label='Delete'
colorScheme={'red'}
flexGrow={1}
variant={'outline'}
isDisabled={Boolean(intermediateImage)}
/>
</DeleteImageModalButton>
</Flex>
)}
<Center height={height} position={'relative'}>
{imageToDisplay && (
<Image
src={imageToDisplay.url}
fit='contain'
maxWidth={'100%'}
maxHeight={'100%'}
/>
)}
{imageToDisplay && shouldShowImageDetails && (
<Flex
width={'100%'}
height={'100%'}
position={'absolute'}
top={0}
left={0}
p={3}
boxSizing='border-box'
backgroundColor={bgColor}
overflow='scroll'
>
<ImageMetadataViewer image={imageToDisplay} />
</Flex>
)}
</Center>
</Flex>
);
};
export default CurrentImage;

View File

@ -0,0 +1,155 @@
import { Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash';
import * as InvokeAI from '../../app/invokeai';
import { useAppDispatch, useAppSelector } from '../../app/store';
import { RootState } from '../../app/store';
import {
setAllParameters,
setInitialImagePath,
setSeed,
} from '../options/optionsSlice';
import DeleteImageModal from './DeleteImageModal';
import { SystemState } from '../system/systemSlice';
import SDButton from '../../common/components/SDButton';
import { runESRGAN, runGFPGAN } from '../../app/socketio/actions';
const systemSelector = createSelector(
(state: RootState) => state.system,
(system: SystemState) => {
return {
isProcessing: system.isProcessing,
isConnected: system.isConnected,
isGFPGANAvailable: system.isGFPGANAvailable,
isESRGANAvailable: system.isESRGANAvailable,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
type CurrentImageButtonsProps = {
image: InvokeAI.Image;
shouldShowImageDetails: boolean;
setShouldShowImageDetails: (b: boolean) => void;
};
/**
* Row of buttons for common actions:
* Use as init image, use all params, use seed, upscale, fix faces, details, delete.
*/
const CurrentImageButtons = ({
image,
shouldShowImageDetails,
setShouldShowImageDetails,
}: CurrentImageButtonsProps) => {
const dispatch = useAppDispatch();
const { intermediateImage } = useAppSelector(
(state: RootState) => state.gallery
);
const { upscalingLevel, gfpganStrength } = useAppSelector(
(state: RootState) => state.options
);
const { isProcessing, isConnected, isGFPGANAvailable, isESRGANAvailable } =
useAppSelector(systemSelector);
const handleClickUseAsInitialImage = () =>
dispatch(setInitialImagePath(image.url));
const handleClickUseAllParameters = () =>
dispatch(setAllParameters(image.metadata));
// Non-null assertion: this button is disabled if there is no seed.
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
const handleClickUseSeed = () => dispatch(setSeed(image.metadata.image.seed));
const handleClickUpscale = () => dispatch(runESRGAN(image));
const handleClickFixFaces = () => dispatch(runGFPGAN(image));
const handleClickShowImageDetails = () =>
setShouldShowImageDetails(!shouldShowImageDetails);
return (
<Flex gap={2}>
<SDButton
label="Use as initial image"
colorScheme={'gray'}
flexGrow={1}
variant={'outline'}
onClick={handleClickUseAsInitialImage}
/>
<SDButton
label="Use all"
colorScheme={'gray'}
flexGrow={1}
variant={'outline'}
isDisabled={!['txt2img', 'img2img'].includes(image.metadata.image.type)}
onClick={handleClickUseAllParameters}
/>
<SDButton
label="Use seed"
colorScheme={'gray'}
flexGrow={1}
variant={'outline'}
isDisabled={!image.metadata.image.seed}
onClick={handleClickUseSeed}
/>
<SDButton
label="Upscale"
colorScheme={'gray'}
flexGrow={1}
variant={'outline'}
isDisabled={
!isESRGANAvailable ||
Boolean(intermediateImage) ||
!(isConnected && !isProcessing) ||
!upscalingLevel
}
onClick={handleClickUpscale}
/>
<SDButton
label="Fix faces"
colorScheme={'gray'}
flexGrow={1}
variant={'outline'}
isDisabled={
!isGFPGANAvailable ||
Boolean(intermediateImage) ||
!(isConnected && !isProcessing) ||
!gfpganStrength
}
onClick={handleClickFixFaces}
/>
<SDButton
label="Details"
colorScheme={'gray'}
variant={shouldShowImageDetails ? 'solid' : 'outline'}
borderWidth={1}
flexGrow={1}
onClick={handleClickShowImageDetails}
/>
<DeleteImageModal image={image}>
<SDButton
label="Delete"
colorScheme={'red'}
flexGrow={1}
variant={'outline'}
isDisabled={Boolean(intermediateImage)}
/>
</DeleteImageModal>
</Flex>
);
};
export default CurrentImageButtons;

View File

@ -0,0 +1,67 @@
import { Center, Flex, Image, Text, useColorModeValue } from '@chakra-ui/react';
import { useAppSelector } from '../../app/store';
import { RootState } from '../../app/store';
import { useState } from 'react';
import ImageMetadataViewer from './ImageMetadataViewer';
import CurrentImageButtons from './CurrentImageButtons';
// TODO: With CSS Grid I had a hard time centering the image in a grid item. This is needed for that.
const height = 'calc(100vh - 238px)';
/**
* Displays the current image if there is one, plus associated actions.
*/
const CurrentImageDisplay = () => {
const { currentImage, intermediateImage } = useAppSelector(
(state: RootState) => state.gallery
);
const bgColor = useColorModeValue(
'rgba(255, 255, 255, 0.85)',
'rgba(0, 0, 0, 0.8)'
);
const [shouldShowImageDetails, setShouldShowImageDetails] =
useState<boolean>(false);
const imageToDisplay = intermediateImage || currentImage;
return imageToDisplay ? (
<Flex direction={'column'} borderWidth={1} rounded={'md'} p={2} gap={2}>
<CurrentImageButtons
image={imageToDisplay}
shouldShowImageDetails={shouldShowImageDetails}
setShouldShowImageDetails={setShouldShowImageDetails}
/>
<Center height={height} position={'relative'}>
<Image
src={imageToDisplay.url}
fit="contain"
maxWidth={'100%'}
maxHeight={'100%'}
/>
{shouldShowImageDetails && (
<Flex
width={'100%'}
height={'100%'}
position={'absolute'}
top={0}
left={0}
p={3}
boxSizing="border-box"
backgroundColor={bgColor}
overflow="scroll"
>
<ImageMetadataViewer image={imageToDisplay} />
</Flex>
)}
</Center>
</Flex>
) : (
<Center height={'100%'} position={'relative'}>
<Text size={'xl'}>No image selected</Text>
</Center>
);
};
export default CurrentImageDisplay;

View File

@ -0,0 +1,125 @@
import {
Text,
AlertDialog,
AlertDialogBody,
AlertDialogFooter,
AlertDialogHeader,
AlertDialogContent,
AlertDialogOverlay,
useDisclosure,
Button,
Switch,
FormControl,
FormLabel,
Flex,
} from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import {
ChangeEvent,
cloneElement,
forwardRef,
ReactElement,
SyntheticEvent,
useRef,
} from 'react';
import { useAppDispatch, useAppSelector } from '../../app/store';
import { deleteImage } from '../../app/socketio/actions';
import { RootState } from '../../app/store';
import { setShouldConfirmOnDelete, SystemState } from '../system/systemSlice';
import * as InvokeAI from '../../app/invokeai';
interface DeleteImageModalProps {
/**
* Component which, on click, should delete the image/open the modal.
*/
children: ReactElement;
/**
* The image to delete.
*/
image: InvokeAI.Image;
}
const systemSelector = createSelector(
(state: RootState) => state.system,
(system: SystemState) => system.shouldConfirmOnDelete
);
/**
* Needs a child, which will act as the button to delete an image.
* If system.shouldConfirmOnDelete is true, a confirmation modal is displayed.
* If it is false, the image is deleted immediately.
* The confirmation modal has a "Don't ask me again" switch to set the boolean.
*/
const DeleteImageModal = forwardRef(
({ image, children }: DeleteImageModalProps, ref) => {
const { isOpen, onOpen, onClose } = useDisclosure();
const dispatch = useAppDispatch();
const shouldConfirmOnDelete = useAppSelector(systemSelector);
const cancelRef = useRef<HTMLButtonElement>(null);
const handleClickDelete = (e: SyntheticEvent) => {
e.stopPropagation();
shouldConfirmOnDelete ? onOpen() : handleDelete();
};
const handleDelete = () => {
dispatch(deleteImage(image));
onClose();
};
const handleChangeShouldConfirmOnDelete = (
e: ChangeEvent<HTMLInputElement>
) => dispatch(setShouldConfirmOnDelete(!e.target.checked));
return (
<>
{cloneElement(children, {
// TODO: This feels wrong.
onClick: handleClickDelete,
ref: ref,
})}
<AlertDialog
isOpen={isOpen}
leastDestructiveRef={cancelRef}
onClose={onClose}
>
<AlertDialogOverlay>
<AlertDialogContent>
<AlertDialogHeader fontSize="lg" fontWeight="bold">
Delete image
</AlertDialogHeader>
<AlertDialogBody>
<Flex direction={'column'} gap={5}>
<Text>
Are you sure? You can't undo this action afterwards.
</Text>
<FormControl>
<Flex alignItems={'center'}>
<FormLabel mb={0}>Don't ask me again</FormLabel>
<Switch
checked={!shouldConfirmOnDelete}
onChange={handleChangeShouldConfirmOnDelete}
/>
</Flex>
</FormControl>
</Flex>
</AlertDialogBody>
<AlertDialogFooter>
<Button ref={cancelRef} onClick={onClose}>
Cancel
</Button>
<Button colorScheme="red" onClick={handleDelete} ml={3}>
Delete
</Button>
</AlertDialogFooter>
</AlertDialogContent>
</AlertDialogOverlay>
</AlertDialog>
</>
);
}
);
export default DeleteImageModal;

View File

@ -1,94 +0,0 @@
import {
IconButtonProps,
Modal,
ModalBody,
ModalCloseButton,
ModalContent,
ModalFooter,
ModalHeader,
ModalOverlay,
Text,
useDisclosure,
} from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import {
cloneElement,
ReactElement,
SyntheticEvent,
} from 'react';
import { useAppDispatch, useAppSelector } from '../../app/hooks';
import { deleteImage } from '../../app/socketio';
import { RootState } from '../../app/store';
import SDButton from '../../components/SDButton';
import { setShouldConfirmOnDelete, SystemState } from '../system/systemSlice';
import { SDImage } from './gallerySlice';
interface Props extends IconButtonProps {
image: SDImage;
'aria-label': string;
children: ReactElement;
}
const systemSelector = createSelector(
(state: RootState) => state.system,
(system: SystemState) => system.shouldConfirmOnDelete
);
/*
TODO: The modal and button to open it should be two different components,
but their state is closely related and I'm not sure how best to accomplish it.
*/
const DeleteImageModalButton = (props: Omit<Props, 'aria-label'>) => {
const { isOpen, onOpen, onClose } = useDisclosure();
const dispatch = useAppDispatch();
const shouldConfirmOnDelete = useAppSelector(systemSelector);
const handleClickDelete = (e: SyntheticEvent) => {
e.stopPropagation();
shouldConfirmOnDelete ? onOpen() : handleDelete();
};
const { image, children } = props;
const handleDelete = () => {
dispatch(deleteImage(image));
onClose();
};
const handleDeleteAndDontAsk = () => {
dispatch(deleteImage(image));
dispatch(setShouldConfirmOnDelete(false));
onClose();
};
return (
<>
{cloneElement(children, {
onClick: handleClickDelete,
})}
<Modal isOpen={isOpen} onClose={onClose}>
<ModalOverlay />
<ModalContent>
<ModalHeader>Are you sure you want to delete this image?</ModalHeader>
<ModalCloseButton />
<ModalBody>
<Text>It will be deleted forever!</Text>
</ModalBody>
<ModalFooter justifyContent={'space-between'}>
<SDButton label={'Yes'} colorScheme='red' onClick={handleDelete} />
<SDButton
label={"Yes, and don't ask me again"}
colorScheme='red'
onClick={handleDeleteAndDontAsk}
/>
<SDButton label='Cancel' colorScheme='blue' onClick={onClose} />
</ModalFooter>
</ModalContent>
</Modal>
</>
);
};
export default DeleteImageModalButton;

View File

@ -0,0 +1,143 @@
import {
Box,
Flex,
Icon,
IconButton,
Image,
Tooltip,
useColorModeValue,
} from '@chakra-ui/react';
import { useAppDispatch } from '../../app/store';
import { setCurrentImage } from './gallerySlice';
import { FaCheck, FaSeedling, FaTrashAlt } from 'react-icons/fa';
import DeleteImageModal from './DeleteImageModal';
import { memo, SyntheticEvent, useState } from 'react';
import { setAllParameters, setSeed } from '../options/optionsSlice';
import * as InvokeAI from '../../app/invokeai';
import { IoArrowUndoCircleOutline } from 'react-icons/io5';
interface HoverableImageProps {
image: InvokeAI.Image;
isSelected: boolean;
}
const memoEqualityCheck = (
prev: HoverableImageProps,
next: HoverableImageProps
) => prev.image.uuid === next.image.uuid && prev.isSelected === next.isSelected;
/**
* Gallery image component with delete/use all/use seed buttons on hover.
*/
const HoverableImage = memo((props: HoverableImageProps) => {
const [isHovered, setIsHovered] = useState<boolean>(false);
const dispatch = useAppDispatch();
const checkColor = useColorModeValue('green.600', 'green.300');
const bgColor = useColorModeValue('gray.200', 'gray.700');
const bgGradient = useColorModeValue(
'radial-gradient(circle, rgba(255,255,255,0.7) 0%, rgba(255,255,255,0.7) 20%, rgba(0,0,0,0) 100%)',
'radial-gradient(circle, rgba(0,0,0,0.7) 0%, rgba(0,0,0,0.7) 20%, rgba(0,0,0,0) 100%)'
);
const { image, isSelected } = props;
const { url, uuid, metadata } = image;
const handleMouseOver = () => setIsHovered(true);
const handleMouseOut = () => setIsHovered(false);
const handleClickSetAllParameters = (e: SyntheticEvent) => {
e.stopPropagation();
dispatch(setAllParameters(metadata));
};
const handleClickSetSeed = (e: SyntheticEvent) => {
e.stopPropagation();
// Non-null assertion: this button is not rendered unless this exists
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
dispatch(setSeed(image.metadata.image.seed));
};
const handleClickImage = () => dispatch(setCurrentImage(image));
return (
<Box position={'relative'} key={uuid}>
<Image
width={120}
height={120}
objectFit="cover"
rounded={'md'}
src={url}
loading={'lazy'}
backgroundColor={bgColor}
/>
<Flex
cursor={'pointer'}
position={'absolute'}
top={0}
left={0}
rounded={'md'}
width="100%"
height="100%"
alignItems={'center'}
justifyContent={'center'}
background={isSelected ? bgGradient : undefined}
onClick={handleClickImage}
onMouseOver={handleMouseOver}
onMouseOut={handleMouseOut}
>
{isSelected && (
<Icon fill={checkColor} width={'50%'} height={'50%'} as={FaCheck} />
)}
{isHovered && (
<Flex
direction={'column'}
gap={1}
position={'absolute'}
top={1}
right={1}
>
<Tooltip label={'Delete image'}>
<DeleteImageModal image={image}>
<IconButton
colorScheme="red"
aria-label="Delete image"
icon={<FaTrashAlt />}
size="xs"
variant={'imageHoverIconButton'}
fontSize={14}
/>
</DeleteImageModal>
</Tooltip>
{['txt2img', 'img2img'].includes(image.metadata.image.type) && (
<Tooltip label="Use all parameters">
<IconButton
aria-label="Use all parameters"
icon={<IoArrowUndoCircleOutline />}
size="xs"
fontSize={18}
variant={'imageHoverIconButton'}
onClickCapture={handleClickSetAllParameters}
/>
</Tooltip>
)}
{image.metadata.image.seed && (
<Tooltip label="Use seed">
<IconButton
aria-label="Use seed"
icon={<FaSeedling />}
size="xs"
fontSize={16}
variant={'imageHoverIconButton'}
onClickCapture={handleClickSetSeed}
/>
</Tooltip>
)}
</Flex>
)}
</Flex>
</Box>
);
}, memoEqualityCheck);
export default HoverableImage;

View File

@ -0,0 +1,39 @@
import { Center, Flex, Text } from '@chakra-ui/react';
import { RootState } from '../../app/store';
import { useAppSelector } from '../../app/store';
import HoverableImage from './HoverableImage';
/**
* Simple image gallery.
*/
const ImageGallery = () => {
const { images, currentImageUuid } = useAppSelector(
(state: RootState) => state.gallery
);
/**
* I don't like that this needs to rerender whenever the current image is changed.
* What if we have a large number of images? I suppose pagination (planned) will
* mitigate this issue.
*
* TODO: Refactor if performance complaints, or after migrating to new API which supports pagination.
*/
return images.length ? (
<Flex gap={2} wrap="wrap" pb={2}>
{[...images].reverse().map((image) => {
const { uuid } = image;
const isSelected = currentImageUuid === uuid;
return (
<HoverableImage key={uuid} image={image} isSelected={isSelected} />
);
})}
</Flex>
) : (
<Center height={'100%'} position={'relative'}>
<Text size={'xl'}>No images in gallery</Text>
</Center>
);
};
export default ImageGallery;

View File

@ -1,124 +1,326 @@
import {
Center,
Flex,
IconButton,
Link,
List,
ListItem,
Text,
Box,
Center,
Flex,
IconButton,
Link,
Text,
Tooltip,
useColorModeValue,
} from '@chakra-ui/react';
import { FaPlus } from 'react-icons/fa';
import { PARAMETERS } from '../../app/constants';
import { useAppDispatch } from '../../app/hooks';
import SDButton from '../../components/SDButton';
import { setAllParameters, setParameter } from '../sd/sdSlice';
import { SDImage, SDMetadata } from './gallerySlice';
import { ExternalLinkIcon } from '@chakra-ui/icons';
import { memo } from 'react';
import { IoArrowUndoCircleOutline } from 'react-icons/io5';
import { useAppDispatch } from '../../app/store';
import * as InvokeAI from '../../app/invokeai';
import {
setCfgScale,
setGfpganStrength,
setHeight,
setImg2imgStrength,
setInitialImagePath,
setMaskPath,
setPrompt,
setSampler,
setSeed,
setSeedWeights,
setShouldFitToWidthHeight,
setSteps,
setUpscalingLevel,
setUpscalingStrength,
setWidth,
} from '../options/optionsSlice';
import promptToString from '../../common/util/promptToString';
import { seedWeightsToString } from '../../common/util/seedWeightPairs';
import { FaCopy } from 'react-icons/fa';
type Props = {
image: SDImage;
type MetadataItemProps = {
isLink?: boolean;
label: string;
onClick?: () => void;
value: number | string | boolean;
};
const ImageMetadataViewer = ({ image }: Props) => {
const dispatch = useAppDispatch();
/**
* Component to display an individual metadata item or parameter.
*/
const MetadataItem = ({ label, value, onClick, isLink }: MetadataItemProps) => {
return (
<Flex gap={2}>
{onClick && (
<Tooltip label={`Recall ${label}`}>
<IconButton
aria-label="Use this parameter"
icon={<IoArrowUndoCircleOutline />}
size={'xs'}
variant={'ghost'}
fontSize={20}
onClick={onClick}
/>
</Tooltip>
)}
<Text fontWeight={'semibold'} whiteSpace={'nowrap'}>
{label}:
</Text>
{isLink ? (
<Link href={value.toString()} isExternal wordBreak={'break-all'}>
{value.toString()} <ExternalLinkIcon mx="2px" />
</Link>
) : (
<Text maxHeight={100} overflowY={'scroll'} wordBreak={'break-all'}>
{value.toString()}
</Text>
)}
</Flex>
);
};
const keys = Object.keys(PARAMETERS);
type ImageMetadataViewerProps = {
image: InvokeAI.Image;
};
const metadata: Array<{
label: string;
key: string;
value: string | number | boolean;
}> = [];
// TODO: I don't know if this is needed.
const memoEqualityCheck = (
prev: ImageMetadataViewerProps,
next: ImageMetadataViewerProps
) => prev.image.uuid === next.image.uuid;
keys.forEach((key) => {
const value = image.metadata[key as keyof SDMetadata];
if (value !== undefined) {
metadata.push({ label: PARAMETERS[key], key, value });
}
});
// TODO: Show more interesting information in this component.
return (
<Flex gap={2} direction={'column'} overflowY={'scroll'} width={'100%'}>
<SDButton
label='Use all parameters'
colorScheme={'gray'}
padding={2}
isDisabled={metadata.length === 0}
onClick={() => dispatch(setAllParameters(image.metadata))}
/**
* Image metadata viewer overlays currently selected image and provides
* access to any of its metadata for use in processing.
*/
const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
const dispatch = useAppDispatch();
const jsonBgColor = useColorModeValue('blackAlpha.100', 'whiteAlpha.100');
const metadata = image.metadata.image;
const {
type,
postprocessing,
sampler,
prompt,
seed,
variations,
steps,
cfg_scale,
seamless,
width,
height,
strength,
fit,
init_image_path,
mask_image_path,
orig_path,
scale,
} = metadata;
const metadataJSON = JSON.stringify(metadata, null, 2);
return (
<Flex
gap={1}
direction={'column'}
overflowY={'scroll'}
width={'100%'}
>
<Flex gap={2}>
<Text fontWeight={'semibold'}>File:</Text>
<Link href={image.url} isExternal>
{image.url}
<ExternalLinkIcon mx="2px" />
</Link>
</Flex>
{Object.keys(metadata).length ? (
<>
{type && <MetadataItem label="Type" value={type} />}
{['esrgan', 'gfpgan'].includes(type) && (
<MetadataItem label="Original image" value={orig_path} isLink />
)}
{type === 'gfpgan' && strength && (
<MetadataItem
label="Fix faces strength"
value={strength}
onClick={() => dispatch(setGfpganStrength(strength))}
/>
<Flex gap={2}>
<Text fontWeight={'semibold'}>File:</Text>
<Link href={image.url} isExternal>
<Text>{image.url}</Text>
</Link>
</Flex>
{metadata.length ? (
<>
<List>
{metadata.map((parameter, i) => {
const { label, key, value } = parameter;
return (
<ListItem key={i} pb={1}>
<Flex gap={2}>
<IconButton
aria-label='Use this parameter'
icon={<FaPlus />}
size={'xs'}
onClick={() =>
dispatch(
setParameter({
key,
value,
})
)
}
/>
<Text fontWeight={'semibold'}>
{label}:
</Text>
{value === undefined ||
value === null ||
value === '' ||
value === 0 ? (
<Text
maxHeight={100}
fontStyle={'italic'}
>
None
</Text>
) : (
<Text
maxHeight={100}
overflowY={'scroll'}
>
{value.toString()}
</Text>
)}
</Flex>
</ListItem>
);
})}
</List>
<Flex gap={2}>
<Text fontWeight={'semibold'}>Raw:</Text>
<Text
maxHeight={100}
overflowY={'scroll'}
wordBreak={'break-all'}
>
{JSON.stringify(image.metadata)}
</Text>
</Flex>
</>
) : (
<Center width={'100%'} pt={10}>
<Text fontSize={'lg'} fontWeight='semibold'>
No metadata available
</Text>
</Center>
)}
{type === 'esrgan' && scale && (
<MetadataItem
label="Upscaling scale"
value={scale}
onClick={() => dispatch(setUpscalingLevel(scale))}
/>
)}
{type === 'esrgan' && strength && (
<MetadataItem
label="Upscaling strength"
value={strength}
onClick={() => dispatch(setUpscalingStrength(strength))}
/>
)}
{prompt && (
<MetadataItem
label="Prompt"
value={promptToString(prompt)}
onClick={() => dispatch(setPrompt(prompt))}
/>
)}
{seed && (
<MetadataItem
label="Seed"
value={seed}
onClick={() => dispatch(setSeed(seed))}
/>
)}
{sampler && (
<MetadataItem
label="Sampler"
value={sampler}
onClick={() => dispatch(setSampler(sampler))}
/>
)}
{steps && (
<MetadataItem
label="Steps"
value={steps}
onClick={() => dispatch(setSteps(steps))}
/>
)}
{cfg_scale && (
<MetadataItem
label="CFG scale"
value={cfg_scale}
onClick={() => dispatch(setCfgScale(cfg_scale))}
/>
)}
{variations && variations.length > 0 && (
<MetadataItem
label="Seed-weight pairs"
value={seedWeightsToString(variations)}
onClick={() =>
dispatch(setSeedWeights(seedWeightsToString(variations)))
}
/>
)}
{seamless && (
<MetadataItem
label="Seamless"
value={seamless}
onClick={() => dispatch(setWidth(seamless))}
/>
)}
{width && (
<MetadataItem
label="Width"
value={width}
onClick={() => dispatch(setWidth(width))}
/>
)}
{height && (
<MetadataItem
label="Height"
value={height}
onClick={() => dispatch(setHeight(height))}
/>
)}
{init_image_path && (
<MetadataItem
label="Initial image"
value={init_image_path}
isLink
onClick={() => dispatch(setInitialImagePath(init_image_path))}
/>
)}
{mask_image_path && (
<MetadataItem
label="Mask image"
value={mask_image_path}
isLink
onClick={() => dispatch(setMaskPath(mask_image_path))}
/>
)}
{type === 'img2img' && strength && (
<MetadataItem
label="Image to image strength"
value={strength}
onClick={() => dispatch(setImg2imgStrength(strength))}
/>
)}
{fit && (
<MetadataItem
label="Image to image fit"
value={fit}
onClick={() => dispatch(setShouldFitToWidthHeight(fit))}
/>
)}
{postprocessing &&
postprocessing.length > 0 &&
postprocessing.map(
(postprocess: InvokeAI.PostProcessedImageMetadata) => {
if (postprocess.type === 'esrgan') {
const { scale, strength } = postprocess;
return (
<>
<MetadataItem
label="Upscaling scale"
value={scale}
onClick={() => dispatch(setUpscalingLevel(scale))}
/>
<MetadataItem
label="Upscaling strength"
value={strength}
onClick={() => dispatch(setUpscalingStrength(strength))}
/>
</>
);
} else if (postprocess.type === 'gfpgan') {
const { strength } = postprocess;
return (
<MetadataItem
label="Fix faces strength"
value={strength}
onClick={() => dispatch(setGfpganStrength(strength))}
/>
);
}
}
)}
</Flex>
);
};
<Flex gap={2} direction={'column'}>
<Flex gap={2}>
<Tooltip label={`Copy JSON`}>
<IconButton
aria-label="Copy JSON"
icon={<FaCopy />}
size={'xs'}
variant={'ghost'}
fontSize={14}
onClick={() => navigator.clipboard.writeText(metadataJSON)}
/>
</Tooltip>
<Text fontWeight={'semibold'}>JSON:</Text>
</Flex>
<Box
// maxHeight={200}
overflow={'scroll'}
flexGrow={3}
wordBreak={'break-all'}
bgColor={jsonBgColor}
padding={2}
>
<pre>{metadataJSON}</pre>
</Box>
</Flex>
</>
) : (
<Center width={'100%'} pt={10}>
<Text fontSize={'lg'} fontWeight="semibold">
No metadata available
</Text>
</Center>
)}
</Flex>
);
}, memoEqualityCheck);
export default ImageMetadataViewer;

View File

@ -1,150 +0,0 @@
import {
Box,
Flex,
Icon,
IconButton,
Image,
useColorModeValue,
} from '@chakra-ui/react';
import { RootState } from '../../app/store';
import { useAppDispatch, useAppSelector } from '../../app/hooks';
import { SDImage, setCurrentImage } from './gallerySlice';
import { FaCheck, FaCopy, FaSeedling, FaTrash } from 'react-icons/fa';
import DeleteImageModalButton from './DeleteImageModalButton';
import { memo, SyntheticEvent, useState } from 'react';
import { setAllParameters, setSeed } from '../sd/sdSlice';
interface HoverableImageProps {
image: SDImage;
isSelected: boolean;
}
const HoverableImage = memo(
(props: HoverableImageProps) => {
const [isHovered, setIsHovered] = useState<boolean>(false);
const dispatch = useAppDispatch();
const checkColor = useColorModeValue('green.600', 'green.300');
const bgColor = useColorModeValue('gray.200', 'gray.700');
const bgGradient = useColorModeValue(
'radial-gradient(circle, rgba(255,255,255,0.7) 0%, rgba(255,255,255,0.7) 20%, rgba(0,0,0,0) 100%)',
'radial-gradient(circle, rgba(0,0,0,0.7) 0%, rgba(0,0,0,0.7) 20%, rgba(0,0,0,0) 100%)'
);
const { image, isSelected } = props;
const { url, uuid, metadata } = image;
const handleMouseOver = () => setIsHovered(true);
const handleMouseOut = () => setIsHovered(false);
const handleClickSetAllParameters = (e: SyntheticEvent) => {
e.stopPropagation();
dispatch(setAllParameters(metadata));
};
const handleClickSetSeed = (e: SyntheticEvent) => {
e.stopPropagation();
dispatch(setSeed(image.metadata.seed!)); // component not rendered unless this exists
};
return (
<Box position={'relative'} key={uuid}>
<Image
width={120}
height={120}
objectFit='cover'
rounded={'md'}
src={url}
loading={'lazy'}
backgroundColor={bgColor}
/>
<Flex
cursor={'pointer'}
position={'absolute'}
top={0}
left={0}
rounded={'md'}
width='100%'
height='100%'
alignItems={'center'}
justifyContent={'center'}
background={isSelected ? bgGradient : undefined}
onClick={() => dispatch(setCurrentImage(image))}
onMouseOver={handleMouseOver}
onMouseOut={handleMouseOut}
>
{isSelected && (
<Icon
fill={checkColor}
width={'50%'}
height={'50%'}
as={FaCheck}
/>
)}
{isHovered && (
<Flex
direction={'column'}
gap={1}
position={'absolute'}
top={1}
right={1}
>
<DeleteImageModalButton image={image}>
<IconButton
colorScheme='red'
aria-label='Delete image'
icon={<FaTrash />}
size='xs'
fontSize={15}
/>
</DeleteImageModalButton>
<IconButton
aria-label='Use all parameters'
colorScheme={'blue'}
icon={<FaCopy />}
size='xs'
fontSize={15}
onClickCapture={handleClickSetAllParameters}
/>
{image.metadata.seed && (
<IconButton
aria-label='Use seed'
colorScheme={'blue'}
icon={<FaSeedling />}
size='xs'
fontSize={16}
onClickCapture={handleClickSetSeed}
/>
)}
</Flex>
)}
</Flex>
</Box>
);
},
(prev, next) =>
prev.image.uuid === next.image.uuid &&
prev.isSelected === next.isSelected
);
const ImageRoll = () => {
const { images, currentImageUuid } = useAppSelector(
(state: RootState) => state.gallery
);
return (
<Flex gap={2} wrap='wrap' pb={2}>
{[...images].reverse().map((image) => {
const { uuid } = image;
const isSelected = currentImageUuid === uuid;
return (
<HoverableImage
key={uuid}
image={image}
isSelected={isSelected}
/>
);
})}
</Flex>
);
};
export default ImageRoll;

View File

@ -1,40 +1,13 @@
import { createSlice } from '@reduxjs/toolkit';
import type { PayloadAction } from '@reduxjs/toolkit';
import { v4 as uuidv4 } from 'uuid';
import { UpscalingLevel } from '../sd/sdSlice';
import { backendToFrontendParameters } from '../../app/parameterTranslation';
// TODO: Revise pending metadata RFC: https://github.com/lstein/stable-diffusion/issues/266
export interface SDMetadata {
prompt?: string;
steps?: number;
cfgScale?: number;
height?: number;
width?: number;
sampler?: string;
seed?: number;
img2imgStrength?: number;
gfpganStrength?: number;
upscalingLevel?: UpscalingLevel;
upscalingStrength?: number;
initialImagePath?: string;
maskPath?: string;
seamless?: boolean;
shouldFitToWidthHeight?: boolean;
}
export interface SDImage {
// TODO: I have installed @types/uuid but cannot figure out how to use them here.
uuid: string;
url: string;
metadata: SDMetadata;
}
import { clamp } from 'lodash';
import * as InvokeAI from '../../app/invokeai';
export interface GalleryState {
currentImage?: InvokeAI.Image;
currentImageUuid: string;
images: Array<SDImage>;
intermediateImage?: SDImage;
currentImage?: SDImage;
images: Array<InvokeAI.Image>;
intermediateImage?: InvokeAI.Image;
}
const initialState: GalleryState = {
@ -46,99 +19,84 @@ export const gallerySlice = createSlice({
name: 'gallery',
initialState,
reducers: {
setCurrentImage: (state, action: PayloadAction<SDImage>) => {
setCurrentImage: (state, action: PayloadAction<InvokeAI.Image>) => {
state.currentImage = action.payload;
state.currentImageUuid = action.payload.uuid;
},
removeImage: (state, action: PayloadAction<SDImage>) => {
const { uuid } = action.payload;
removeImage: (state, action: PayloadAction<string>) => {
const uuid = action.payload;
const newImages = state.images.filter((image) => image.uuid !== uuid);
const imageToDeleteIndex = state.images.findIndex(
(image) => image.uuid === uuid
);
if (uuid === state.currentImageUuid) {
/**
* We are deleting the currently selected image.
*
* We want the new currentl selected image to be under the cursor in the
* gallery, so we need to do some fanagling. The currently selected image
* is set by its UUID, not its index in the image list.
*
* Get the currently selected image's index.
*/
const imageToDeleteIndex = state.images.findIndex(
(image) => image.uuid === uuid
);
const newCurrentImageIndex = Math.min(
Math.max(imageToDeleteIndex, 0),
newImages.length - 1
);
/**
* New current image needs to be in the same spot, but because the gallery
* is sorted in reverse order, the new current image's index will actuall be
* one less than the deleted image's index.
*
* Clamp the new index to ensure it is valid..
*/
const newCurrentImageIndex = clamp(
imageToDeleteIndex - 1,
0,
newImages.length - 1
);
state.currentImage = newImages.length
? newImages[newCurrentImageIndex]
: undefined;
state.currentImageUuid = newImages.length
? newImages[newCurrentImageIndex].uuid
: '';
}
state.images = newImages;
state.currentImage = newImages.length
? newImages[newCurrentImageIndex]
: undefined;
state.currentImageUuid = newImages.length
? newImages[newCurrentImageIndex].uuid
: '';
},
addImage: (state, action: PayloadAction<SDImage>) => {
addImage: (state, action: PayloadAction<InvokeAI.Image>) => {
state.images.push(action.payload);
state.currentImageUuid = action.payload.uuid;
state.intermediateImage = undefined;
state.currentImage = action.payload;
},
setIntermediateImage: (state, action: PayloadAction<SDImage>) => {
setIntermediateImage: (state, action: PayloadAction<InvokeAI.Image>) => {
state.intermediateImage = action.payload;
},
clearIntermediateImage: (state) => {
state.intermediateImage = undefined;
},
setGalleryImages: (
state,
action: PayloadAction<
Array<{
path: string;
metadata: { [key: string]: string | number | boolean };
}>
>
) => {
// TODO: Revise pending metadata RFC: https://github.com/lstein/stable-diffusion/issues/266
const images = action.payload;
if (images.length === 0) {
// there are no images on disk, clear the gallery
state.images = [];
state.currentImageUuid = '';
state.currentImage = undefined;
} else {
// Filter image urls that are already in the rehydrated state
const filteredImages = action.payload.filter(
(image) => !state.images.find((i) => i.url === image.path)
);
const preparedImages = filteredImages.map((image): SDImage => {
return {
uuid: uuidv4(),
url: image.path,
metadata: backendToFrontendParameters(image.metadata),
};
});
const newImages = [...state.images].concat(preparedImages);
// if previous currentimage no longer exists, set a new one
if (!newImages.find((image) => image.uuid === state.currentImageUuid)) {
const newCurrentImage = newImages[newImages.length - 1];
state.currentImage = newCurrentImage;
state.currentImageUuid = newCurrentImage.uuid;
}
setGalleryImages: (state, action: PayloadAction<Array<InvokeAI.Image>>) => {
const newImages = action.payload;
if (newImages.length) {
const newCurrentImage = newImages[newImages.length - 1];
state.images = newImages;
state.currentImage = newCurrentImage;
state.currentImageUuid = newCurrentImage.uuid;
}
},
},
});
export const {
setCurrentImage,
removeImage,
addImage,
clearIntermediateImage,
removeImage,
setCurrentImage,
setGalleryImages,
setIntermediateImage,
clearIntermediateImage,
} = gallerySlice.actions;
export default gallerySlice.reducer;

View File

@ -1,35 +0,0 @@
import { Progress } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash';
import { useAppSelector } from '../../app/hooks';
import { RootState } from '../../app/store';
import { SDState } from '../sd/sdSlice';
const sdSelector = createSelector(
(state: RootState) => state.sd,
(sd: SDState) => {
return {
realSteps: sd.realSteps,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
const ProgressBar = () => {
const { realSteps } = useAppSelector(sdSelector);
const { currentStep } = useAppSelector((state: RootState) => state.system);
const progress = Math.round((currentStep * 100) / realSteps);
return (
<Progress
height='10px'
value={progress}
isIndeterminate={progress < 0 || currentStep === realSteps}
/>
);
};
export default ProgressBar;

View File

@ -1,93 +0,0 @@
import {
Flex,
Heading,
IconButton,
Link,
Spacer,
Text,
useColorMode,
} from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash';
import { FaSun, FaMoon, FaGithub } from 'react-icons/fa';
import { MdHelp, MdSettings } from 'react-icons/md';
import { useAppSelector } from '../../app/hooks';
import { RootState } from '../../app/store';
import SettingsModal from '../system/SettingsModal';
import { SystemState } from '../system/systemSlice';
const systemSelector = createSelector(
(state: RootState) => state.system,
(system: SystemState) => {
return { isConnected: system.isConnected };
},
{
memoizeOptions: { resultEqualityCheck: isEqual },
}
);
const SiteHeader = () => {
const { colorMode, toggleColorMode } = useColorMode();
const { isConnected } = useAppSelector(systemSelector);
return (
<Flex minWidth='max-content' alignItems='center' gap='1' pl={2} pr={1}>
<Heading size={'lg'}>Stable Diffusion Dream Server</Heading>
<Spacer />
<Text textColor={isConnected ? 'green.500' : 'red.500'}>
{isConnected ? `Connected to server` : 'No connection to server'}
</Text>
<SettingsModal>
<IconButton
aria-label='Settings'
variant='link'
fontSize={24}
size={'sm'}
icon={<MdSettings />}
/>
</SettingsModal>
<IconButton
aria-label='Link to Github Issues'
variant='link'
fontSize={23}
size={'sm'}
icon={
<Link
isExternal
href='http://github.com/lstein/stable-diffusion/issues'
>
<MdHelp />
</Link>
}
/>
<IconButton
aria-label='Link to Github Repo'
variant='link'
fontSize={20}
size={'sm'}
icon={
<Link isExternal href='http://github.com/lstein/stable-diffusion'>
<FaGithub />
</Link>
}
/>
<IconButton
aria-label='Toggle Dark Mode'
onClick={toggleColorMode}
variant='link'
size={'sm'}
fontSize={colorMode == 'light' ? 18 : 20}
icon={colorMode == 'light' ? <FaMoon /> : <FaSun />}
/>
</Flex>
);
};
export default SiteHeader;

View File

@ -0,0 +1,87 @@
import { Flex } from '@chakra-ui/react';
import { RootState } from '../../app/store';
import { useAppDispatch, useAppSelector } from '../../app/store';
import {
setUpscalingLevel,
setUpscalingStrength,
UpscalingLevel,
OptionsState,
} from '../options/optionsSlice';
import { UPSCALING_LEVELS } from '../../app/constants';
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash';
import { SystemState } from '../system/systemSlice';
import { ChangeEvent } from 'react';
import SDNumberInput from '../../common/components/SDNumberInput';
import SDSelect from '../../common/components/SDSelect';
const optionsSelector = createSelector(
(state: RootState) => state.options,
(options: OptionsState) => {
return {
upscalingLevel: options.upscalingLevel,
upscalingStrength: options.upscalingStrength,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
const systemSelector = createSelector(
(state: RootState) => state.system,
(system: SystemState) => {
return {
isESRGANAvailable: system.isESRGANAvailable,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
/**
* Displays upscaling/ESRGAN options (level and strength).
*/
const ESRGANOptions = () => {
const dispatch = useAppDispatch();
const { upscalingLevel, upscalingStrength } = useAppSelector(optionsSelector);
const { isESRGANAvailable } = useAppSelector(systemSelector);
const handleChangeLevel = (e: ChangeEvent<HTMLSelectElement>) =>
dispatch(setUpscalingLevel(Number(e.target.value) as UpscalingLevel));
const handleChangeStrength = (v: string | number) =>
dispatch(setUpscalingStrength(Number(v)));
return (
<Flex direction={'column'} gap={2}>
<SDSelect
isDisabled={!isESRGANAvailable}
label="Scale"
value={upscalingLevel}
onChange={handleChangeLevel}
validValues={UPSCALING_LEVELS}
/>
<SDNumberInput
isDisabled={!isESRGANAvailable}
label="Strength"
step={0.05}
min={0}
max={1}
onChange={handleChangeStrength}
value={upscalingStrength}
/>
</Flex>
);
};
export default ESRGANOptions;

View File

@ -0,0 +1,68 @@
import { Flex } from '@chakra-ui/react';
import { RootState } from '../../app/store';
import { useAppDispatch, useAppSelector } from '../../app/store';
import { OptionsState, setGfpganStrength } from '../options/optionsSlice';
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash';
import { SystemState } from '../system/systemSlice';
import SDNumberInput from '../../common/components/SDNumberInput';
const optionsSelector = createSelector(
(state: RootState) => state.options,
(options: OptionsState) => {
return {
gfpganStrength: options.gfpganStrength,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
const systemSelector = createSelector(
(state: RootState) => state.system,
(system: SystemState) => {
return {
isGFPGANAvailable: system.isGFPGANAvailable,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
/**
* Displays face-fixing/GFPGAN options (strength).
*/
const GFPGANOptions = () => {
const dispatch = useAppDispatch();
const { gfpganStrength } = useAppSelector(optionsSelector);
const { isGFPGANAvailable } = useAppSelector(systemSelector);
const handleChangeStrength = (v: string | number) =>
dispatch(setGfpganStrength(Number(v)));
return (
<Flex direction={'column'} gap={2}>
<SDNumberInput
isDisabled={!isGFPGANAvailable}
label="Strength"
step={0.05}
min={0}
max={1}
onChange={handleChangeStrength}
value={gfpganStrength}
/>
</Flex>
);
};
export default GFPGANOptions;

View File

@ -0,0 +1,59 @@
import { Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { ChangeEvent } from 'react';
import { useAppDispatch, useAppSelector } from '../../app/store';
import { RootState } from '../../app/store';
import SDNumberInput from '../../common/components/SDNumberInput';
import SDSwitch from '../../common/components/SDSwitch';
import InitAndMaskImage from './InitAndMaskImage';
import {
OptionsState,
setImg2imgStrength,
setShouldFitToWidthHeight,
} from './optionsSlice';
const optionsSelector = createSelector(
(state: RootState) => state.options,
(options: OptionsState) => {
return {
img2imgStrength: options.img2imgStrength,
shouldFitToWidthHeight: options.shouldFitToWidthHeight,
};
}
);
/**
* Options for img2img generation (strength, fit, init/mask upload).
*/
const ImageToImageOptions = () => {
const dispatch = useAppDispatch();
const { img2imgStrength, shouldFitToWidthHeight } =
useAppSelector(optionsSelector);
const handleChangeStrength = (v: string | number) =>
dispatch(setImg2imgStrength(Number(v)));
const handleChangeFit = (e: ChangeEvent<HTMLInputElement>) =>
dispatch(setShouldFitToWidthHeight(e.target.checked));
return (
<Flex direction={'column'} gap={2}>
<SDNumberInput
label="Strength"
step={0.01}
min={0}
max={1}
onChange={handleChangeStrength}
value={img2imgStrength}
/>
<SDSwitch
label="Fit initial image to output size"
isChecked={shouldFitToWidthHeight}
onChange={handleChangeFit}
/>
<InitAndMaskImage />
</Flex>
);
};
export default ImageToImageOptions;

View File

@ -0,0 +1,64 @@
import { Box } from '@chakra-ui/react';
import { cloneElement, ReactElement, SyntheticEvent, useCallback } from 'react';
import { FileRejection, useDropzone } from 'react-dropzone';
type ImageUploaderProps = {
/**
* Component which, on click, should open the upload interface.
*/
children: ReactElement;
/**
* Callback to handle uploading the selected file.
*/
fileAcceptedCallback: (file: File) => void;
/**
* Callback to handle a file being rejected.
*/
fileRejectionCallback: (rejection: FileRejection) => void;
};
/**
* File upload using react-dropzone.
* Needs a child to be the button to activate the upload interface.
*/
const ImageUploader = ({
children,
fileAcceptedCallback,
fileRejectionCallback,
}: ImageUploaderProps) => {
const onDrop = useCallback(
(acceptedFiles: Array<File>, fileRejections: Array<FileRejection>) => {
fileRejections.forEach((rejection: FileRejection) => {
fileRejectionCallback(rejection);
});
acceptedFiles.forEach((file: File) => {
fileAcceptedCallback(file);
});
},
[fileAcceptedCallback, fileRejectionCallback]
);
const { getRootProps, getInputProps, open } = useDropzone({
onDrop,
accept: {
'image/jpeg': ['.jpg', '.jpeg', '.png'],
},
});
const handleClickUploadIcon = (e: SyntheticEvent) => {
e.stopPropagation();
open();
};
return (
<Box {...getRootProps()} flexGrow={3}>
<input {...getInputProps({ multiple: false })} />
{cloneElement(children, {
onClick: handleClickUploadIcon,
})}
</Box>
);
};
export default ImageUploader;

View File

@ -0,0 +1,57 @@
import { Flex, Image } from '@chakra-ui/react';
import { useState } from 'react';
import { useAppSelector } from '../../app/store';
import { RootState } from '../../app/store';
import { OptionsState } from '../../features/options/optionsSlice';
import './InitAndMaskImage.css';
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash';
import InitAndMaskUploadButtons from './InitAndMaskUploadButtons';
const optionsSelector = createSelector(
(state: RootState) => state.options,
(options: OptionsState) => {
return {
initialImagePath: options.initialImagePath,
maskPath: options.maskPath,
};
},
{ memoizeOptions: { resultEqualityCheck: isEqual } }
);
/**
* Displays init and mask images and buttons to upload/delete them.
*/
const InitAndMaskImage = () => {
const { initialImagePath, maskPath } = useAppSelector(optionsSelector);
const [shouldShowMask, setShouldShowMask] = useState<boolean>(false);
return (
<Flex direction={'column'} alignItems={'center'} gap={2}>
<InitAndMaskUploadButtons setShouldShowMask={setShouldShowMask} />
{initialImagePath && (
<Flex position={'relative'} width={'100%'}>
<Image
fit={'contain'}
src={initialImagePath}
rounded={'md'}
className={'checkerboard'}
/>
{shouldShowMask && maskPath && (
<Image
position={'absolute'}
top={0}
left={0}
fit={'contain'}
src={maskPath}
rounded={'md'}
zIndex={1}
/>
)}
</Flex>
)}
</Flex>
);
};
export default InitAndMaskImage;

View File

@ -0,0 +1,151 @@
import { Button, Flex, IconButton, useToast } from '@chakra-ui/react';
import { SyntheticEvent, useCallback } from 'react';
import { FaTrash, FaUpload } from 'react-icons/fa';
import { useAppDispatch, useAppSelector } from '../../app/store';
import { RootState } from '../../app/store';
import {
OptionsState,
setInitialImagePath,
setMaskPath,
} from '../../features/options/optionsSlice';
import {
uploadInitialImage,
uploadMaskImage,
} from '../../app/socketio/actions';
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash';
import ImageUploader from './ImageUploader';
import { FileRejection } from 'react-dropzone';
const optionsSelector = createSelector(
(state: RootState) => state.options,
(options: OptionsState) => {
return {
initialImagePath: options.initialImagePath,
maskPath: options.maskPath,
};
},
{ memoizeOptions: { resultEqualityCheck: isEqual } }
);
type InitAndMaskUploadButtonsProps = {
setShouldShowMask: (b: boolean) => void;
};
/**
* Init and mask image upload buttons.
*/
const InitAndMaskUploadButtons = ({
setShouldShowMask,
}: InitAndMaskUploadButtonsProps) => {
const dispatch = useAppDispatch();
const { initialImagePath, maskPath } = useAppSelector(optionsSelector);
// Use a toast to alert user when a file upload is rejected
const toast = useToast();
// Clear the init and mask images
const handleClickResetInitialImage = (e: SyntheticEvent) => {
e.stopPropagation();
dispatch(setInitialImagePath(''));
};
// Clear the init and mask images
const handleClickResetMask = (e: SyntheticEvent) => {
e.stopPropagation();
dispatch(setMaskPath(''));
};
// Handle hover to view initial image and mask image
const handleMouseOverInitialImageUploadButton = () =>
setShouldShowMask(false);
const handleMouseOutInitialImageUploadButton = () => setShouldShowMask(true);
const handleMouseOverMaskUploadButton = () => setShouldShowMask(true);
const handleMouseOutMaskUploadButton = () => setShouldShowMask(true);
// Callbacks to for handling file upload attempts
const initImageFileAcceptedCallback = useCallback(
(file: File) => dispatch(uploadInitialImage(file)),
[dispatch]
);
const maskImageFileAcceptedCallback = useCallback(
(file: File) => dispatch(uploadMaskImage(file)),
[dispatch]
);
const fileRejectionCallback = useCallback(
(rejection: FileRejection) => {
const msg = rejection.errors.reduce(
(acc: string, cur: { message: string }) => acc + '\n' + cur.message,
''
);
toast({
title: 'Upload failed',
description: msg,
status: 'error',
isClosable: true,
});
},
[toast]
);
return (
<Flex gap={2} justifyContent={'space-between'} width={'100%'}>
<ImageUploader
fileAcceptedCallback={initImageFileAcceptedCallback}
fileRejectionCallback={fileRejectionCallback}
>
<Button
size={'sm'}
fontSize={'md'}
fontWeight={'normal'}
onMouseOver={handleMouseOverInitialImageUploadButton}
onMouseOut={handleMouseOutInitialImageUploadButton}
leftIcon={<FaUpload />}
width={'100%'}
>
Image
</Button>
</ImageUploader>
<IconButton
isDisabled={!initialImagePath}
size={'sm'}
aria-label={'Reset mask'}
onClick={handleClickResetInitialImage}
icon={<FaTrash />}
/>
<ImageUploader
fileAcceptedCallback={maskImageFileAcceptedCallback}
fileRejectionCallback={fileRejectionCallback}
>
<Button
isDisabled={!initialImagePath}
size={'sm'}
fontSize={'md'}
fontWeight={'normal'}
onMouseOver={handleMouseOverMaskUploadButton}
onMouseOut={handleMouseOutMaskUploadButton}
leftIcon={<FaUpload />}
width={'100%'}
>
Mask
</Button>
</ImageUploader>
<IconButton
isDisabled={!maskPath}
size={'sm'}
aria-label={'Reset mask'}
onClick={handleClickResetMask}
icon={<FaTrash />}
/>
</Flex>
);
};
export default InitAndMaskUploadButtons;

View File

@ -0,0 +1,217 @@
import {
Flex,
Box,
Text,
Accordion,
AccordionItem,
AccordionButton,
AccordionIcon,
AccordionPanel,
Switch,
ExpandedIndex,
} from '@chakra-ui/react';
import { RootState } from '../../app/store';
import { useAppDispatch, useAppSelector } from '../../app/store';
import {
setShouldRunGFPGAN,
setShouldRunESRGAN,
OptionsState,
setShouldUseInitImage,
} from '../options/optionsSlice';
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash';
import { setOpenAccordions, SystemState } from '../system/systemSlice';
import SeedVariationOptions from './SeedVariationOptions';
import SamplerOptions from './SamplerOptions';
import ESRGANOptions from './ESRGANOptions';
import GFPGANOptions from './GFPGANOptions';
import OutputOptions from './OutputOptions';
import ImageToImageOptions from './ImageToImageOptions';
import { ChangeEvent } from 'react';
const optionsSelector = createSelector(
(state: RootState) => state.options,
(options: OptionsState) => {
return {
initialImagePath: options.initialImagePath,
shouldUseInitImage: options.shouldUseInitImage,
shouldRunESRGAN: options.shouldRunESRGAN,
shouldRunGFPGAN: options.shouldRunGFPGAN,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
const systemSelector = createSelector(
(state: RootState) => state.system,
(system: SystemState) => {
return {
isGFPGANAvailable: system.isGFPGANAvailable,
isESRGANAvailable: system.isESRGANAvailable,
openAccordions: system.openAccordions,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
/**
* Main container for generation and processing parameters.
*/
const OptionsAccordion = () => {
const {
shouldRunESRGAN,
shouldRunGFPGAN,
shouldUseInitImage,
initialImagePath,
} = useAppSelector(optionsSelector);
const { isGFPGANAvailable, isESRGANAvailable, openAccordions } =
useAppSelector(systemSelector);
const dispatch = useAppDispatch();
/**
* Stores accordion state in redux so preferred UI setup is retained.
*/
const handleChangeAccordionState = (openAccordions: ExpandedIndex) =>
dispatch(setOpenAccordions(openAccordions));
const handleChangeShouldRunESRGAN = (e: ChangeEvent<HTMLInputElement>) =>
dispatch(setShouldRunESRGAN(e.target.checked));
const handleChangeShouldRunGFPGAN = (e: ChangeEvent<HTMLInputElement>) =>
dispatch(setShouldRunGFPGAN(e.target.checked));
const handleChangeShouldUseInitImage = (e: ChangeEvent<HTMLInputElement>) =>
dispatch(setShouldUseInitImage(e.target.checked));
return (
<Accordion
defaultIndex={openAccordions}
allowMultiple
reduceMotion
onChange={handleChangeAccordionState}
>
<AccordionItem>
<h2>
<AccordionButton>
<Box flex="1" textAlign="left">
Seed & Variation
</Box>
<AccordionIcon />
</AccordionButton>
</h2>
<AccordionPanel>
<SeedVariationOptions />
</AccordionPanel>
</AccordionItem>
<AccordionItem>
<h2>
<AccordionButton>
<Box flex="1" textAlign="left">
Sampler
</Box>
<AccordionIcon />
</AccordionButton>
</h2>
<AccordionPanel>
<SamplerOptions />
</AccordionPanel>
</AccordionItem>
<AccordionItem>
<h2>
<AccordionButton>
<Flex
justifyContent={'space-between'}
alignItems={'center'}
width={'100%'}
mr={2}
>
<Text>Upscale (ESRGAN)</Text>
<Switch
isDisabled={!isESRGANAvailable}
isChecked={shouldRunESRGAN}
onChange={handleChangeShouldRunESRGAN}
/>
</Flex>
<AccordionIcon />
</AccordionButton>
</h2>
<AccordionPanel>
<ESRGANOptions />
</AccordionPanel>
</AccordionItem>
<AccordionItem>
<h2>
<AccordionButton>
<Flex
justifyContent={'space-between'}
alignItems={'center'}
width={'100%'}
mr={2}
>
<Text>Fix Faces (GFPGAN)</Text>
<Switch
isDisabled={!isGFPGANAvailable}
isChecked={shouldRunGFPGAN}
onChange={handleChangeShouldRunGFPGAN}
/>
</Flex>
<AccordionIcon />
</AccordionButton>
</h2>
<AccordionPanel>
<GFPGANOptions />
</AccordionPanel>
</AccordionItem>
<AccordionItem>
<h2>
<AccordionButton>
<Flex
justifyContent={'space-between'}
alignItems={'center'}
width={'100%'}
mr={2}
>
<Text>Image to Image</Text>
<Switch
isDisabled={!initialImagePath}
isChecked={shouldUseInitImage}
onChange={handleChangeShouldUseInitImage}
/>
</Flex>
<AccordionIcon />
</AccordionButton>
</h2>
<AccordionPanel>
<ImageToImageOptions />
</AccordionPanel>
</AccordionItem>
<AccordionItem>
<h2>
<AccordionButton>
<Box flex="1" textAlign="left">
Output
</Box>
<AccordionIcon />
</AccordionButton>
</h2>
<AccordionPanel>
<OutputOptions />
</AccordionPanel>
</AccordionItem>
</Accordion>
);
};
export default OptionsAccordion;

View File

@ -0,0 +1,76 @@
import { Flex } from '@chakra-ui/react';
import { RootState } from '../../app/store';
import { useAppDispatch, useAppSelector } from '../../app/store';
import { setHeight, setWidth, setSeamless, OptionsState } from '../options/optionsSlice';
import { HEIGHTS, WIDTHS } from '../../app/constants';
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash';
import { ChangeEvent } from 'react';
import SDSelect from '../../common/components/SDSelect';
import SDSwitch from '../../common/components/SDSwitch';
const optionsSelector = createSelector(
(state: RootState) => state.options,
(options: OptionsState) => {
return {
height: options.height,
width: options.width,
seamless: options.seamless,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
/**
* Image output options. Includes width, height, seamless tiling.
*/
const OutputOptions = () => {
const dispatch = useAppDispatch();
const { height, width, seamless } = useAppSelector(optionsSelector);
const handleChangeWidth = (e: ChangeEvent<HTMLSelectElement>) =>
dispatch(setWidth(Number(e.target.value)));
const handleChangeHeight = (e: ChangeEvent<HTMLSelectElement>) =>
dispatch(setHeight(Number(e.target.value)));
const handleChangeSeamless = (e: ChangeEvent<HTMLInputElement>) =>
dispatch(setSeamless(e.target.checked));
return (
<Flex gap={2} direction={'column'}>
<Flex gap={2}>
<SDSelect
label="Width"
value={width}
flexGrow={1}
onChange={handleChangeWidth}
validValues={WIDTHS}
/>
<SDSelect
label="Height"
value={height}
flexGrow={1}
onChange={handleChangeHeight}
validValues={HEIGHTS}
/>
</Flex>
<SDSwitch
label="Seamless tiling"
fontSize={'md'}
isChecked={seamless}
onChange={handleChangeSeamless}
/>
</Flex>
);
};
export default OutputOptions;

View File

@ -0,0 +1,68 @@
import { Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash';
import { useAppDispatch, useAppSelector } from '../../app/store';
import { cancelProcessing, generateImage } from '../../app/socketio/actions';
import { RootState } from '../../app/store';
import SDButton from '../../common/components/SDButton';
import useCheckParameters from '../../common/hooks/useCheckParameters';
import { SystemState } from '../system/systemSlice';
const systemSelector = createSelector(
(state: RootState) => state.system,
(system: SystemState) => {
return {
isProcessing: system.isProcessing,
isConnected: system.isConnected,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
/**
* Buttons to start and cancel image generation.
*/
const ProcessButtons = () => {
const dispatch = useAppDispatch();
const { isProcessing, isConnected } = useAppSelector(systemSelector);
const isReady = useCheckParameters();
const handleClickGenerate = () => dispatch(generateImage());
const handleClickCancel = () => dispatch(cancelProcessing());
return (
<Flex
gap={2}
direction={'column'}
alignItems={'space-between'}
height={'100%'}
>
<SDButton
label="Generate"
type="submit"
colorScheme="green"
flexGrow={1}
isDisabled={!isReady}
fontSize={'md'}
size={'md'}
onClick={handleClickGenerate}
/>
<SDButton
label="Cancel"
colorScheme="red"
flexGrow={1}
fontSize={'md'}
size={'md'}
isDisabled={!isConnected || !isProcessing}
onClick={handleClickCancel}
/>
</Flex>
);
};
export default ProcessButtons;

View File

@ -0,0 +1,44 @@
import { Textarea } from '@chakra-ui/react';
import {
ChangeEvent,
KeyboardEvent,
} from 'react';
import { useAppDispatch, useAppSelector } from '../../app/store';
import { generateImage } from '../../app/socketio/actions';
import { RootState } from '../../app/store';
import { setPrompt } from '../options/optionsSlice';
/**
* Prompt input text area.
*/
const PromptInput = () => {
const { prompt } = useAppSelector((state: RootState) => state.options);
const dispatch = useAppDispatch();
const handleChangePrompt = (e: ChangeEvent<HTMLTextAreaElement>) =>
dispatch(setPrompt(e.target.value));
const handleKeyDown = (e: KeyboardEvent<HTMLTextAreaElement>) => {
if (e.key === 'Enter' && e.shiftKey === false) {
e.preventDefault();
dispatch(generateImage())
}
};
return (
<Textarea
id="prompt"
name="prompt"
resize="none"
size={'lg'}
height={'100%'}
isInvalid={!prompt.length}
onChange={handleChangePrompt}
onKeyDown={handleKeyDown}
value={prompt}
placeholder="I'm dreaming of..."
/>
);
};
export default PromptInput;

View File

@ -0,0 +1,74 @@
import { Flex } from '@chakra-ui/react';
import { RootState } from '../../app/store';
import { useAppDispatch, useAppSelector } from '../../app/store';
import { setCfgScale, setSampler, setSteps, OptionsState } from '../options/optionsSlice';
import { SAMPLERS } from '../../app/constants';
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash';
import { ChangeEvent } from 'react';
import SDNumberInput from '../../common/components/SDNumberInput';
import SDSelect from '../../common/components/SDSelect';
const optionsSelector = createSelector(
(state: RootState) => state.options,
(options: OptionsState) => {
return {
steps: options.steps,
cfgScale: options.cfgScale,
sampler: options.sampler,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
/**
* Sampler options. Includes steps, CFG scale, sampler.
*/
const SamplerOptions = () => {
const dispatch = useAppDispatch();
const { steps, cfgScale, sampler } = useAppSelector(optionsSelector);
const handleChangeSteps = (v: string | number) =>
dispatch(setSteps(Number(v)));
const handleChangeCfgScale = (v: string | number) =>
dispatch(setCfgScale(Number(v)));
const handleChangeSampler = (e: ChangeEvent<HTMLSelectElement>) =>
dispatch(setSampler(e.target.value));
return (
<Flex gap={2} direction={'column'}>
<SDNumberInput
label="Steps"
min={1}
step={1}
precision={0}
onChange={handleChangeSteps}
value={steps}
/>
<SDNumberInput
label="CFG scale"
step={0.5}
onChange={handleChangeCfgScale}
value={cfgScale}
/>
<SDSelect
label="Sampler"
value={sampler}
onChange={handleChangeSampler}
validValues={SAMPLERS}
/>
</Flex>
);
};
export default SamplerOptions;

View File

@ -0,0 +1,159 @@
import {
Flex,
Input,
HStack,
FormControl,
FormLabel,
Text,
Button,
} from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash';
import { ChangeEvent } from 'react';
import { NUMPY_RAND_MAX, NUMPY_RAND_MIN } from '../../app/constants';
import { useAppDispatch, useAppSelector } from '../../app/store';
import { RootState } from '../../app/store';
import SDNumberInput from '../../common/components/SDNumberInput';
import SDSwitch from '../../common/components/SDSwitch';
import randomInt from '../../common/util/randomInt';
import { validateSeedWeights } from '../../common/util/seedWeightPairs';
import {
OptionsState,
setIterations,
setSeed,
setSeedWeights,
setShouldGenerateVariations,
setShouldRandomizeSeed,
setVariationAmount,
} from './optionsSlice';
const optionsSelector = createSelector(
(state: RootState) => state.options,
(options: OptionsState) => {
return {
variationAmount: options.variationAmount,
seedWeights: options.seedWeights,
shouldGenerateVariations: options.shouldGenerateVariations,
shouldRandomizeSeed: options.shouldRandomizeSeed,
seed: options.seed,
iterations: options.iterations,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
/**
* Seed & variation options. Includes iteration, seed, seed randomization, variation options.
*/
const SeedVariationOptions = () => {
const {
shouldGenerateVariations,
variationAmount,
seedWeights,
shouldRandomizeSeed,
seed,
iterations,
} = useAppSelector(optionsSelector);
const dispatch = useAppDispatch();
const handleChangeIterations = (v: string | number) =>
dispatch(setIterations(Number(v)));
const handleChangeShouldRandomizeSeed = (e: ChangeEvent<HTMLInputElement>) =>
dispatch(setShouldRandomizeSeed(e.target.checked));
const handleChangeSeed = (v: string | number) => dispatch(setSeed(Number(v)));
const handleClickRandomizeSeed = () =>
dispatch(setSeed(randomInt(NUMPY_RAND_MIN, NUMPY_RAND_MAX)));
const handleChangeShouldGenerateVariations = (
e: ChangeEvent<HTMLInputElement>
) => dispatch(setShouldGenerateVariations(e.target.checked));
const handleChangevariationAmount = (v: string | number) =>
dispatch(setVariationAmount(Number(v)));
const handleChangeSeedWeights = (e: ChangeEvent<HTMLInputElement>) =>
dispatch(setSeedWeights(e.target.value));
return (
<Flex gap={2} direction={'column'}>
<SDNumberInput
label="Images to generate"
step={1}
min={1}
precision={0}
onChange={handleChangeIterations}
value={iterations}
/>
<SDSwitch
label="Randomize seed on generation"
isChecked={shouldRandomizeSeed}
onChange={handleChangeShouldRandomizeSeed}
/>
<Flex gap={2}>
<SDNumberInput
label="Seed"
step={1}
precision={0}
flexGrow={1}
min={NUMPY_RAND_MIN}
max={NUMPY_RAND_MAX}
isDisabled={shouldRandomizeSeed}
isInvalid={seed < 0 && shouldGenerateVariations}
onChange={handleChangeSeed}
value={seed}
/>
<Button
size={'sm'}
isDisabled={shouldRandomizeSeed}
onClick={handleClickRandomizeSeed}
>
<Text pl={2} pr={2}>
Shuffle
</Text>
</Button>
</Flex>
<SDSwitch
label="Generate variations"
isChecked={shouldGenerateVariations}
width={'auto'}
onChange={handleChangeShouldGenerateVariations}
/>
<SDNumberInput
label="Variation amount"
value={variationAmount}
step={0.01}
min={0}
max={1}
onChange={handleChangevariationAmount}
/>
<FormControl
isInvalid={
shouldGenerateVariations &&
!(validateSeedWeights(seedWeights) || seedWeights === '')
}
flexGrow={1}
>
<HStack>
<FormLabel marginInlineEnd={0} marginBottom={1}>
<Text whiteSpace="nowrap">Seed Weights</Text>
</FormLabel>
<Input
size={'sm'}
value={seedWeights}
onChange={handleChangeSeedWeights}
/>
</HStack>
</FormControl>
</Flex>
);
};
export default SeedVariationOptions;

View File

@ -1,24 +1,15 @@
import { createSlice } from '@reduxjs/toolkit';
import type { PayloadAction } from '@reduxjs/toolkit';
import { SDMetadata } from '../gallery/gallerySlice';
import randomInt from './util/randomInt';
import { NUMPY_RAND_MAX, NUMPY_RAND_MIN } from '../../app/constants';
import * as InvokeAI from '../../app/invokeai';
import promptToString from '../../common/util/promptToString';
import { seedWeightsToString } from '../../common/util/seedWeightPairs';
const calculateRealSteps = (
steps: number,
strength: number,
hasInitImage: boolean
): number => {
return hasInitImage ? Math.floor(strength * steps) : steps;
};
export type UpscalingLevel = 2 | 4;
export type UpscalingLevel = 0 | 2 | 3 | 4;
export interface SDState {
export interface OptionsState {
prompt: string;
iterations: number;
steps: number;
realSteps: number;
cfgScale: number;
height: number;
width: number;
@ -34,18 +25,17 @@ export interface SDState {
seamless: boolean;
shouldFitToWidthHeight: boolean;
shouldGenerateVariations: boolean;
variantAmount: number;
variationAmount: number;
seedWeights: string;
shouldRunESRGAN: boolean;
shouldRunGFPGAN: boolean;
shouldRandomizeSeed: boolean;
}
const initialSDState: SDState = {
const initialOptionsState: OptionsState = {
prompt: '',
iterations: 1,
steps: 50,
realSteps: 50,
cfgScale: 7.5,
height: 512,
width: 512,
@ -58,7 +48,7 @@ const initialSDState: SDState = {
maskPath: '',
shouldFitToWidthHeight: true,
shouldGenerateVariations: false,
variantAmount: 0.1,
variationAmount: 0.1,
seedWeights: '',
shouldRunESRGAN: false,
upscalingLevel: 4,
@ -68,27 +58,25 @@ const initialSDState: SDState = {
shouldRandomizeSeed: true,
};
const initialState: SDState = initialSDState;
const initialState: OptionsState = initialOptionsState;
export const sdSlice = createSlice({
name: 'sd',
export const optionsSlice = createSlice({
name: 'options',
initialState,
reducers: {
setPrompt: (state, action: PayloadAction<string>) => {
state.prompt = action.payload;
setPrompt: (state, action: PayloadAction<string | InvokeAI.Prompt>) => {
const newPrompt = action.payload;
if (typeof newPrompt === 'string') {
state.prompt = newPrompt;
} else {
state.prompt = promptToString(newPrompt);
}
},
setIterations: (state, action: PayloadAction<number>) => {
state.iterations = action.payload;
},
setSteps: (state, action: PayloadAction<number>) => {
const { img2imgStrength, initialImagePath } = state;
const steps = action.payload;
state.steps = steps;
state.realSteps = calculateRealSteps(
steps,
img2imgStrength,
Boolean(initialImagePath)
);
state.steps = action.payload;
},
setCfgScale: (state, action: PayloadAction<number>) => {
state.cfgScale = action.payload;
@ -107,14 +95,7 @@ export const sdSlice = createSlice({
state.shouldRandomizeSeed = false;
},
setImg2imgStrength: (state, action: PayloadAction<number>) => {
const img2imgStrength = action.payload;
const { steps, initialImagePath } = state;
state.img2imgStrength = img2imgStrength;
state.realSteps = calculateRealSteps(
steps,
img2imgStrength,
Boolean(initialImagePath)
);
state.img2imgStrength = action.payload;
},
setGfpganStrength: (state, action: PayloadAction<number>) => {
state.gfpganStrength = action.payload;
@ -129,15 +110,9 @@ export const sdSlice = createSlice({
state.shouldUseInitImage = action.payload;
},
setInitialImagePath: (state, action: PayloadAction<string>) => {
const initialImagePath = action.payload;
const { steps, img2imgStrength } = state;
state.shouldUseInitImage = initialImagePath ? true : false;
state.initialImagePath = initialImagePath;
state.realSteps = calculateRealSteps(
steps,
img2imgStrength,
Boolean(initialImagePath)
);
const newInitialImagePath = action.payload;
state.shouldUseInitImage = newInitialImagePath ? true : false;
state.initialImagePath = newInitialImagePath;
},
setMaskPath: (state, action: PayloadAction<string>) => {
state.maskPath = action.payload;
@ -151,13 +126,11 @@ export const sdSlice = createSlice({
resetSeed: (state) => {
state.seed = -1;
},
randomizeSeed: (state) => {
state.seed = randomInt(NUMPY_RAND_MIN, NUMPY_RAND_MAX);
},
setParameter: (
state,
action: PayloadAction<{ key: string; value: string | number | boolean }>
) => {
// TODO: This probably needs to be refactored.
const { key, value } = action.payload;
const temp = { ...state, [key]: value };
if (key === 'seed') {
@ -171,70 +144,95 @@ export const sdSlice = createSlice({
setShouldGenerateVariations: (state, action: PayloadAction<boolean>) => {
state.shouldGenerateVariations = action.payload;
},
setVariantAmount: (state, action: PayloadAction<number>) => {
state.variantAmount = action.payload;
setVariationAmount: (state, action: PayloadAction<number>) => {
state.variationAmount = action.payload;
},
setSeedWeights: (state, action: PayloadAction<string>) => {
state.seedWeights = action.payload;
},
setAllParameters: (state, action: PayloadAction<SDMetadata>) => {
setAllParameters: (state, action: PayloadAction<InvokeAI.Metadata>) => {
const {
prompt,
steps,
cfgScale,
height,
width,
type,
postprocessing,
sampler,
prompt,
seed,
img2imgStrength,
gfpganStrength,
upscalingLevel,
upscalingStrength,
initialImagePath,
maskPath,
variations,
steps,
cfg_scale,
seamless,
shouldFitToWidthHeight,
} = action.payload;
width,
height,
strength,
fit,
init_image_path,
mask_image_path,
} = action.payload.image;
// ?? = falsy values ('', 0, etc) are used
// || = falsy values not used
state.prompt = prompt ?? state.prompt;
state.steps = steps || state.steps;
state.cfgScale = cfgScale || state.cfgScale;
state.width = width || state.width;
state.height = height || state.height;
state.sampler = sampler || state.sampler;
state.seed = seed ?? state.seed;
state.seamless = seamless ?? state.seamless;
state.shouldFitToWidthHeight =
shouldFitToWidthHeight ?? state.shouldFitToWidthHeight;
state.img2imgStrength = img2imgStrength ?? state.img2imgStrength;
state.gfpganStrength = gfpganStrength ?? state.gfpganStrength;
state.upscalingLevel = upscalingLevel ?? state.upscalingLevel;
state.upscalingStrength = upscalingStrength ?? state.upscalingStrength;
state.initialImagePath = initialImagePath ?? state.initialImagePath;
state.maskPath = maskPath ?? state.maskPath;
if (type === 'img2img') {
if (init_image_path) state.initialImagePath = init_image_path;
if (mask_image_path) state.maskPath = mask_image_path;
if (strength) state.img2imgStrength = strength;
if (typeof fit === 'boolean') state.shouldFitToWidthHeight = fit;
state.shouldUseInitImage = true;
} else {
state.shouldUseInitImage = false;
}
if (variations && variations.length > 0) {
state.seedWeights = seedWeightsToString(variations);
state.shouldGenerateVariations = true;
} else {
state.shouldGenerateVariations = false;
}
// If the image whose parameters we are using has a seed, disable randomizing the seed
if (seed) {
state.seed = seed;
state.shouldRandomizeSeed = false;
}
// if we have a gfpgan strength, enable it
state.shouldRunGFPGAN = gfpganStrength ? true : false;
let postprocessingNotDone = ['gfpgan', 'esrgan'];
if (postprocessing && postprocessing.length > 0) {
postprocessing.forEach(
(postprocess: InvokeAI.PostProcessedImageMetadata) => {
if (postprocess.type === 'gfpgan') {
const { strength } = postprocess;
if (strength) state.gfpganStrength = strength;
state.shouldRunGFPGAN = true;
postprocessingNotDone = postprocessingNotDone.filter(
(p) => p !== 'gfpgan'
);
}
if (postprocess.type === 'esrgan') {
const { scale, strength } = postprocess;
if (scale) state.upscalingLevel = scale;
if (strength) state.upscalingStrength = strength;
state.shouldRunESRGAN = true;
postprocessingNotDone = postprocessingNotDone.filter(
(p) => p !== 'esrgan'
);
}
}
);
}
// if we have a esrgan strength, enable it
state.shouldRunESRGAN = upscalingLevel ? true : false;
postprocessingNotDone.forEach((p) => {
if (p === 'esrgan') state.shouldRunESRGAN = false;
if (p === 'gfpgan') state.shouldRunGFPGAN = false;
});
// if we want to recreate an image exactly, we disable variations
state.shouldGenerateVariations = false;
state.shouldUseInitImage = initialImagePath ? true : false;
if (prompt) state.prompt = promptToString(prompt);
if (sampler) state.sampler = sampler;
if (steps) state.steps = steps;
if (cfg_scale) state.cfgScale = cfg_scale;
if (typeof seamless === 'boolean') state.seamless = seamless;
if (width) state.width = width;
if (height) state.height = height;
},
resetSDState: (state) => {
resetOptionsState: (state) => {
return {
...state,
...initialSDState,
...initialOptionsState,
};
},
setShouldRunGFPGAN: (state, action: PayloadAction<boolean>) => {
@ -267,17 +265,16 @@ export const {
setInitialImagePath,
setMaskPath,
resetSeed,
randomizeSeed,
resetSDState,
resetOptionsState,
setShouldFitToWidthHeight,
setParameter,
setShouldGenerateVariations,
setSeedWeights,
setVariantAmount,
setVariationAmount,
setAllParameters,
setShouldRunGFPGAN,
setShouldRunESRGAN,
setShouldRandomizeSeed,
} = sdSlice.actions;
} = optionsSlice.actions;
export default sdSlice.reducer;
export default optionsSlice.reducer;

View File

@ -1,84 +0,0 @@
import { Flex } from '@chakra-ui/react';
import { RootState } from '../../app/store';
import { useAppDispatch, useAppSelector } from '../../app/hooks';
import {
setUpscalingLevel,
setUpscalingStrength,
UpscalingLevel,
SDState,
} from '../sd/sdSlice';
import SDNumberInput from '../../components/SDNumberInput';
import SDSelect from '../../components/SDSelect';
import { UPSCALING_LEVELS } from '../../app/constants';
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash';
import { SystemState } from '../system/systemSlice';
const sdSelector = createSelector(
(state: RootState) => state.sd,
(sd: SDState) => {
return {
upscalingLevel: sd.upscalingLevel,
upscalingStrength: sd.upscalingStrength,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
const systemSelector = createSelector(
(state: RootState) => state.system,
(system: SystemState) => {
return {
isESRGANAvailable: system.isESRGANAvailable,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
const ESRGANOptions = () => {
const { upscalingLevel, upscalingStrength } = useAppSelector(sdSelector);
const { isESRGANAvailable } = useAppSelector(systemSelector);
const dispatch = useAppDispatch();
return (
<Flex direction={'column'} gap={2}>
<SDSelect
isDisabled={!isESRGANAvailable}
label='Scale'
value={upscalingLevel}
onChange={(e) =>
dispatch(
setUpscalingLevel(
Number(e.target.value) as UpscalingLevel
)
)
}
validValues={UPSCALING_LEVELS}
/>
<SDNumberInput
isDisabled={!isESRGANAvailable}
label='Strength'
step={0.05}
min={0}
max={1}
onChange={(v) => dispatch(setUpscalingStrength(Number(v)))}
value={upscalingStrength}
/>
</Flex>
);
};
export default ESRGANOptions;

View File

@ -1,63 +0,0 @@
import { Flex } from '@chakra-ui/react';
import { RootState } from '../../app/store';
import { useAppDispatch, useAppSelector } from '../../app/hooks';
import { SDState, setGfpganStrength } from '../sd/sdSlice';
import SDNumberInput from '../../components/SDNumberInput';
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash';
import { SystemState } from '../system/systemSlice';
const sdSelector = createSelector(
(state: RootState) => state.sd,
(sd: SDState) => {
return {
gfpganStrength: sd.gfpganStrength,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
const systemSelector = createSelector(
(state: RootState) => state.system,
(system: SystemState) => {
return {
isGFPGANAvailable: system.isGFPGANAvailable,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
const GFPGANOptions = () => {
const { gfpganStrength } = useAppSelector(sdSelector);
const { isGFPGANAvailable } = useAppSelector(systemSelector);
const dispatch = useAppDispatch();
return (
<Flex direction={'column'} gap={2}>
<SDNumberInput
isDisabled={!isGFPGANAvailable}
label='Strength'
step={0.05}
min={0}
max={1}
onChange={(v) => dispatch(setGfpganStrength(Number(v)))}
value={gfpganStrength}
/>
</Flex>
);
};
export default GFPGANOptions;

View File

@ -1,54 +0,0 @@
import { Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from '../../app/hooks';
import { RootState } from '../../app/store';
import SDNumberInput from '../../components/SDNumberInput';
import SDSwitch from '../../components/SDSwitch';
import InitImage from './InitImage';
import {
SDState,
setImg2imgStrength,
setShouldFitToWidthHeight,
} from './sdSlice';
const sdSelector = createSelector(
(state: RootState) => state.sd,
(sd: SDState) => {
return {
initialImagePath: sd.initialImagePath,
img2imgStrength: sd.img2imgStrength,
shouldFitToWidthHeight: sd.shouldFitToWidthHeight,
};
}
);
const ImageToImageOptions = () => {
const { initialImagePath, img2imgStrength, shouldFitToWidthHeight } =
useAppSelector(sdSelector);
const dispatch = useAppDispatch();
return (
<Flex direction={'column'} gap={2}>
<SDNumberInput
isDisabled={!initialImagePath}
label='Strength'
step={0.01}
min={0}
max={1}
onChange={(v) => dispatch(setImg2imgStrength(Number(v)))}
value={img2imgStrength}
/>
<SDSwitch
isDisabled={!initialImagePath}
label='Fit initial image to output size'
isChecked={shouldFitToWidthHeight}
onChange={(e) =>
dispatch(setShouldFitToWidthHeight(e.target.checked))
}
/>
<InitImage />
</Flex>
);
};
export default ImageToImageOptions;

View File

@ -1,155 +0,0 @@
import {
Button,
Flex,
IconButton,
Image,
useToast,
} from '@chakra-ui/react';
import { SyntheticEvent, useCallback, useState } from 'react';
import { FileRejection, useDropzone } from 'react-dropzone';
import { FaTrash } from 'react-icons/fa';
import { useAppDispatch, useAppSelector } from '../../app/hooks';
import { RootState } from '../../app/store';
import {
SDState,
setInitialImagePath,
setMaskPath,
} from '../../features/sd/sdSlice';
import MaskUploader from './MaskUploader';
import './InitImage.css';
import { uploadInitialImage } from '../../app/socketio';
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash';
const sdSelector = createSelector(
(state: RootState) => state.sd,
(sd: SDState) => {
return {
initialImagePath: sd.initialImagePath,
maskPath: sd.maskPath,
};
},
{ memoizeOptions: { resultEqualityCheck: isEqual } }
);
const InitImage = () => {
const toast = useToast();
const dispatch = useAppDispatch();
const { initialImagePath, maskPath } = useAppSelector(sdSelector);
const onDrop = useCallback(
(acceptedFiles: Array<File>, fileRejections: Array<FileRejection>) => {
fileRejections.forEach((rejection: FileRejection) => {
const msg = rejection.errors.reduce(
(acc: string, cur: { message: string }) => acc + '\n' + cur.message,
''
);
toast({
title: 'Upload failed',
description: msg,
status: 'error',
isClosable: true,
});
});
acceptedFiles.forEach((file: File) => {
dispatch(uploadInitialImage(file));
});
},
[dispatch, toast]
);
const { getRootProps, getInputProps, open } = useDropzone({
onDrop,
accept: {
'image/jpeg': ['.jpg', '.jpeg', '.png'],
},
});
const [shouldShowMask, setShouldShowMask] = useState<boolean>(false);
const handleClickUploadIcon = (e: SyntheticEvent) => {
e.stopPropagation();
open();
};
const handleClickResetInitialImageAndMask = (e: SyntheticEvent) => {
e.stopPropagation();
dispatch(setInitialImagePath(''));
dispatch(setMaskPath(''));
};
const handleMouseOverInitialImageUploadButton = () =>
setShouldShowMask(false);
const handleMouseOutInitialImageUploadButton = () => setShouldShowMask(true);
const handleMouseOverMaskUploadButton = () => setShouldShowMask(true);
const handleMouseOutMaskUploadButton = () => setShouldShowMask(true);
return (
<Flex
{...getRootProps({
onClick: initialImagePath ? (e) => e.stopPropagation() : undefined,
})}
direction={'column'}
alignItems={'center'}
gap={2}
>
<input {...getInputProps({ multiple: false })} />
<Flex gap={2} justifyContent={'space-between'} width={'100%'}>
<Button
size={'sm'}
fontSize={'md'}
fontWeight={'normal'}
onClick={handleClickUploadIcon}
onMouseOver={handleMouseOverInitialImageUploadButton}
onMouseOut={handleMouseOutInitialImageUploadButton}
>
Upload Image
</Button>
<MaskUploader>
<Button
size={'sm'}
fontSize={'md'}
fontWeight={'normal'}
onClick={handleClickUploadIcon}
onMouseOver={handleMouseOverMaskUploadButton}
onMouseOut={handleMouseOutMaskUploadButton}
>
Upload Mask
</Button>
</MaskUploader>
<IconButton
size={'sm'}
aria-label={'Reset initial image and mask'}
onClick={handleClickResetInitialImageAndMask}
icon={<FaTrash />}
/>
</Flex>
{initialImagePath && (
<Flex position={'relative'} width={'100%'}>
<Image
fit={'contain'}
src={initialImagePath}
rounded={'md'}
className={'checkerboard'}
/>
{shouldShowMask && maskPath && (
<Image
position={'absolute'}
top={0}
left={0}
fit={'contain'}
src={maskPath}
rounded={'md'}
zIndex={1}
className={'checkerboard'}
/>
)}
</Flex>
)}
</Flex>
);
};
export default InitImage;

View File

@ -1,61 +0,0 @@
import { useToast } from '@chakra-ui/react';
import { cloneElement, ReactElement, SyntheticEvent, useCallback } from 'react';
import { FileRejection, useDropzone } from 'react-dropzone';
import { useAppDispatch } from '../../app/hooks';
import { uploadMaskImage } from '../../app/socketio';
type Props = {
children: ReactElement;
};
const MaskUploader = ({ children }: Props) => {
const dispatch = useAppDispatch();
const toast = useToast();
const onDrop = useCallback(
(acceptedFiles: Array<File>, fileRejections: Array<FileRejection>) => {
fileRejections.forEach((rejection: FileRejection) => {
const msg = rejection.errors.reduce(
(acc: string, cur: { message: string }) =>
acc + '\n' + cur.message,
''
);
toast({
title: 'Upload failed',
description: msg,
status: 'error',
isClosable: true,
});
});
acceptedFiles.forEach((file: File) => {
dispatch(uploadMaskImage(file));
});
},
[dispatch, toast]
);
const { getRootProps, getInputProps, open } = useDropzone({
onDrop,
accept: {
'image/jpeg': ['.jpg', '.jpeg', '.png'],
},
});
const handleClickUploadIcon = (e: SyntheticEvent) => {
e.stopPropagation();
open();
};
return (
<div {...getRootProps()}>
<input {...getInputProps({ multiple: false })} />
{cloneElement(children, {
onClick: handleClickUploadIcon,
})}
</div>
);
};
export default MaskUploader;

View File

@ -1,211 +0,0 @@
import {
Flex,
Box,
Text,
Accordion,
AccordionItem,
AccordionButton,
AccordionIcon,
AccordionPanel,
Switch,
} from '@chakra-ui/react';
import { RootState } from '../../app/store';
import { useAppDispatch, useAppSelector } from '../../app/hooks';
import {
setShouldRunGFPGAN,
setShouldRunESRGAN,
SDState,
setShouldUseInitImage,
} from '../sd/sdSlice';
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash';
import { setOpenAccordions, SystemState } from '../system/systemSlice';
import SeedVariationOptions from './SeedVariationOptions';
import SamplerOptions from './SamplerOptions';
import ESRGANOptions from './ESRGANOptions';
import GFPGANOptions from './GFPGANOptions';
import OutputOptions from './OutputOptions';
import ImageToImageOptions from './ImageToImageOptions';
const sdSelector = createSelector(
(state: RootState) => state.sd,
(sd: SDState) => {
return {
initialImagePath: sd.initialImagePath,
shouldUseInitImage: sd.shouldUseInitImage,
shouldRunESRGAN: sd.shouldRunESRGAN,
shouldRunGFPGAN: sd.shouldRunGFPGAN,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
const systemSelector = createSelector(
(state: RootState) => state.system,
(system: SystemState) => {
return {
isGFPGANAvailable: system.isGFPGANAvailable,
isESRGANAvailable: system.isESRGANAvailable,
openAccordions: system.openAccordions,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
const OptionsAccordion = () => {
const {
shouldRunESRGAN,
shouldRunGFPGAN,
shouldUseInitImage,
initialImagePath,
} = useAppSelector(sdSelector);
const { isGFPGANAvailable, isESRGANAvailable, openAccordions } =
useAppSelector(systemSelector);
const dispatch = useAppDispatch();
return (
<Accordion
defaultIndex={openAccordions}
allowMultiple
reduceMotion
onChange={(openAccordions) =>
dispatch(setOpenAccordions(openAccordions))
}
>
<AccordionItem>
<h2>
<AccordionButton>
<Box flex='1' textAlign='left'>
Seed & Variation
</Box>
<AccordionIcon />
</AccordionButton>
</h2>
<AccordionPanel>
<SeedVariationOptions />
</AccordionPanel>
</AccordionItem>
<AccordionItem>
<h2>
<AccordionButton>
<Box flex='1' textAlign='left'>
Sampler
</Box>
<AccordionIcon />
</AccordionButton>
</h2>
<AccordionPanel>
<SamplerOptions />
</AccordionPanel>
</AccordionItem>
<AccordionItem>
<h2>
<AccordionButton>
<Flex
justifyContent={'space-between'}
alignItems={'center'}
width={'100%'}
mr={2}
>
<Text>Upscale (ESRGAN)</Text>
<Switch
isDisabled={!isESRGANAvailable}
isChecked={shouldRunESRGAN}
onChange={(e) =>
dispatch(
setShouldRunESRGAN(e.target.checked)
)
}
/>
</Flex>
<AccordionIcon />
</AccordionButton>
</h2>
<AccordionPanel>
<ESRGANOptions />
</AccordionPanel>
</AccordionItem>
<AccordionItem>
<h2>
<AccordionButton>
<Flex
justifyContent={'space-between'}
alignItems={'center'}
width={'100%'}
mr={2}
>
<Text>Fix Faces (GFPGAN)</Text>
<Switch
isDisabled={!isGFPGANAvailable}
isChecked={shouldRunGFPGAN}
onChange={(e) =>
dispatch(
setShouldRunGFPGAN(e.target.checked)
)
}
/>
</Flex>
<AccordionIcon />
</AccordionButton>
</h2>
<AccordionPanel>
<GFPGANOptions />
</AccordionPanel>
</AccordionItem>
<AccordionItem>
<h2>
<AccordionButton>
<Flex
justifyContent={'space-between'}
alignItems={'center'}
width={'100%'}
mr={2}
>
<Text>Image to Image</Text>
<Switch
isDisabled={!initialImagePath}
isChecked={shouldUseInitImage}
onChange={(e) =>
dispatch(
setShouldUseInitImage(e.target.checked)
)
}
/>
</Flex>
<AccordionIcon />
</AccordionButton>
</h2>
<AccordionPanel>
<ImageToImageOptions />
</AccordionPanel>
</AccordionItem>
<AccordionItem>
<h2>
<AccordionButton>
<Box flex='1' textAlign='left'>
Output
</Box>
<AccordionIcon />
</AccordionButton>
</h2>
<AccordionPanel>
<OutputOptions />
</AccordionPanel>
</AccordionItem>
</Accordion>
);
};
export default OptionsAccordion;

View File

@ -1,66 +0,0 @@
import { Flex } from '@chakra-ui/react';
import { RootState } from '../../app/store';
import { useAppDispatch, useAppSelector } from '../../app/hooks';
import { setHeight, setWidth, setSeamless, SDState } from '../sd/sdSlice';
import SDSelect from '../../components/SDSelect';
import { HEIGHTS, WIDTHS } from '../../app/constants';
import SDSwitch from '../../components/SDSwitch';
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash';
const sdSelector = createSelector(
(state: RootState) => state.sd,
(sd: SDState) => {
return {
height: sd.height,
width: sd.width,
seamless: sd.seamless,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
const OutputOptions = () => {
const { height, width, seamless } = useAppSelector(sdSelector);
const dispatch = useAppDispatch();
return (
<Flex gap={2} direction={'column'}>
<Flex gap={2}>
<SDSelect
label='Width'
value={width}
flexGrow={1}
onChange={(e) => dispatch(setWidth(Number(e.target.value)))}
validValues={WIDTHS}
/>
<SDSelect
label='Height'
value={height}
flexGrow={1}
onChange={(e) =>
dispatch(setHeight(Number(e.target.value)))
}
validValues={HEIGHTS}
/>
</Flex>
<SDSwitch
label='Seamless tiling'
fontSize={'md'}
isChecked={seamless}
onChange={(e) => dispatch(setSeamless(e.target.checked))}
/>
</Flex>
);
};
export default OutputOptions;

View File

@ -1,58 +0,0 @@
import { Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash';
import { useAppDispatch, useAppSelector } from '../../app/hooks';
import { cancelProcessing, generateImage } from '../../app/socketio';
import { RootState } from '../../app/store';
import SDButton from '../../components/SDButton';
import { SystemState } from '../system/systemSlice';
import useCheckParameters from '../system/useCheckParameters';
const systemSelector = createSelector(
(state: RootState) => state.system,
(system: SystemState) => {
return {
isProcessing: system.isProcessing,
isConnected: system.isConnected,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
const ProcessButtons = () => {
const { isProcessing, isConnected } = useAppSelector(systemSelector);
const dispatch = useAppDispatch();
const isReady = useCheckParameters();
return (
<Flex gap={2} direction={'column'} alignItems={'space-between'} height={'100%'}>
<SDButton
label='Generate'
type='submit'
colorScheme='green'
flexGrow={1}
isDisabled={!isReady}
fontSize={'md'}
size={'md'}
onClick={() => dispatch(generateImage())}
/>
<SDButton
label='Cancel'
colorScheme='red'
flexGrow={1}
fontSize={'md'}
size={'md'}
isDisabled={!isConnected || !isProcessing}
onClick={() => dispatch(cancelProcessing())}
/>
</Flex>
);
};
export default ProcessButtons;

View File

@ -1,25 +0,0 @@
import { Textarea } from '@chakra-ui/react';
import { useAppDispatch, useAppSelector } from '../../app/hooks';
import { RootState } from '../../app/store';
import { setPrompt } from '../sd/sdSlice';
const PromptInput = () => {
const { prompt } = useAppSelector((state: RootState) => state.sd);
const dispatch = useAppDispatch();
return (
<Textarea
id='prompt'
name='prompt'
resize='none'
size={'lg'}
height={'100%'}
isInvalid={!prompt.length}
onChange={(e) => dispatch(setPrompt(e.target.value))}
value={prompt}
placeholder="I'm dreaming of..."
/>
);
};
export default PromptInput;

View File

@ -1,51 +0,0 @@
import {
Slider,
SliderTrack,
SliderFilledTrack,
SliderThumb,
FormControl,
FormLabel,
Text,
Flex,
SliderProps,
} from '@chakra-ui/react';
interface Props extends SliderProps {
label: string;
value: number;
fontSize?: number | string;
}
const SDSlider = ({
label,
value,
fontSize = 'sm',
onChange,
...rest
}: Props) => {
return (
<FormControl>
<Flex gap={2}>
<FormLabel marginInlineEnd={0} marginBottom={1}>
<Text fontSize={fontSize} whiteSpace='nowrap'>
{label}
</Text>
</FormLabel>
<Slider
aria-label={label}
focusThumbOnChange={true}
value={value}
onChange={onChange}
{...rest}
>
<SliderTrack>
<SliderFilledTrack />
</SliderTrack>
<SliderThumb />
</Slider>
</Flex>
</FormControl>
);
};
export default SDSlider;

View File

@ -1,62 +0,0 @@
import { Flex } from '@chakra-ui/react';
import { RootState } from '../../app/store';
import { useAppDispatch, useAppSelector } from '../../app/hooks';
import { setCfgScale, setSampler, setSteps, SDState } from '../sd/sdSlice';
import SDNumberInput from '../../components/SDNumberInput';
import SDSelect from '../../components/SDSelect';
import { SAMPLERS } from '../../app/constants';
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash';
const sdSelector = createSelector(
(state: RootState) => state.sd,
(sd: SDState) => {
return {
steps: sd.steps,
cfgScale: sd.cfgScale,
sampler: sd.sampler,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
const SamplerOptions = () => {
const { steps, cfgScale, sampler } = useAppSelector(sdSelector);
const dispatch = useAppDispatch();
return (
<Flex gap={2} direction={'column'}>
<SDNumberInput
label='Steps'
min={1}
step={1}
precision={0}
onChange={(v) => dispatch(setSteps(Number(v)))}
value={steps}
/>
<SDNumberInput
label='CFG scale'
step={0.5}
onChange={(v) => dispatch(setCfgScale(Number(v)))}
value={cfgScale}
/>
<SDSelect
label='Sampler'
value={sampler}
onChange={(e) => dispatch(setSampler(e.target.value))}
validValues={SAMPLERS}
/>
</Flex>
);
};
export default SamplerOptions;

View File

@ -1,144 +0,0 @@
import {
Flex,
Input,
HStack,
FormControl,
FormLabel,
Text,
Button,
} from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash';
import { NUMPY_RAND_MAX, NUMPY_RAND_MIN } from '../../app/constants';
import { useAppDispatch, useAppSelector } from '../../app/hooks';
import { RootState } from '../../app/store';
import SDNumberInput from '../../components/SDNumberInput';
import SDSwitch from '../../components/SDSwitch';
import {
randomizeSeed,
SDState,
setIterations,
setSeed,
setSeedWeights,
setShouldGenerateVariations,
setShouldRandomizeSeed,
setVariantAmount,
} from './sdSlice';
import { validateSeedWeights } from './util/seedWeightPairs';
const sdSelector = createSelector(
(state: RootState) => state.sd,
(sd: SDState) => {
return {
variantAmount: sd.variantAmount,
seedWeights: sd.seedWeights,
shouldGenerateVariations: sd.shouldGenerateVariations,
shouldRandomizeSeed: sd.shouldRandomizeSeed,
seed: sd.seed,
iterations: sd.iterations,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
const SeedVariationOptions = () => {
const {
shouldGenerateVariations,
variantAmount,
seedWeights,
shouldRandomizeSeed,
seed,
iterations,
} = useAppSelector(sdSelector);
const dispatch = useAppDispatch();
return (
<Flex gap={2} direction={'column'}>
<SDNumberInput
label='Images to generate'
step={1}
min={1}
precision={0}
onChange={(v) => dispatch(setIterations(Number(v)))}
value={iterations}
/>
<SDSwitch
label='Randomize seed on generation'
isChecked={shouldRandomizeSeed}
onChange={(e) =>
dispatch(setShouldRandomizeSeed(e.target.checked))
}
/>
<Flex gap={2}>
<SDNumberInput
label='Seed'
step={1}
precision={0}
flexGrow={1}
min={NUMPY_RAND_MIN}
max={NUMPY_RAND_MAX}
isDisabled={shouldRandomizeSeed}
isInvalid={seed < 0 && shouldGenerateVariations}
onChange={(v) => dispatch(setSeed(Number(v)))}
value={seed}
/>
<Button
size={'sm'}
isDisabled={shouldRandomizeSeed}
onClick={() => dispatch(randomizeSeed())}
>
<Text pl={2} pr={2}>
Shuffle
</Text>
</Button>
</Flex>
<SDSwitch
label='Generate variations'
isChecked={shouldGenerateVariations}
width={'auto'}
onChange={(e) =>
dispatch(setShouldGenerateVariations(e.target.checked))
}
/>
<SDNumberInput
label='Variation amount'
value={variantAmount}
step={0.01}
min={0}
max={1}
isDisabled={!shouldGenerateVariations}
onChange={(v) => dispatch(setVariantAmount(Number(v)))}
/>
<FormControl
isInvalid={
shouldGenerateVariations &&
!(validateSeedWeights(seedWeights) || seedWeights === '')
}
flexGrow={1}
isDisabled={!shouldGenerateVariations}
>
<HStack>
<FormLabel marginInlineEnd={0} marginBottom={1}>
<Text whiteSpace='nowrap'>
Seed Weights
</Text>
</FormLabel>
<Input
size={'sm'}
value={seedWeights}
onChange={(e) =>
dispatch(setSeedWeights(e.target.value))
}
/>
</HStack>
</FormControl>
</Flex>
);
};
export default SeedVariationOptions;

View File

@ -1,92 +0,0 @@
import {
Flex,
FormControl,
FormLabel,
HStack,
Input,
Text,
} from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash';
import { useAppDispatch, useAppSelector } from '../../app/hooks';
import { RootState } from '../../app/store';
import SDNumberInput from '../../components/SDNumberInput';
import SDSwitch from '../../components/SDSwitch';
import {
SDState,
setSeedWeights,
setShouldGenerateVariations,
setVariantAmount,
} from './sdSlice';
import { validateSeedWeights } from './util/seedWeightPairs';
const sdSelector = createSelector(
(state: RootState) => state.sd,
(sd: SDState) => {
return {
variantAmount: sd.variantAmount,
seedWeights: sd.seedWeights,
shouldGenerateVariations: sd.shouldGenerateVariations,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
const Variant = () => {
const { shouldGenerateVariations, variantAmount, seedWeights } =
useAppSelector(sdSelector);
const dispatch = useAppDispatch();
return (
<Flex gap={2} alignItems={'center'} pl={1}>
<SDSwitch
label='Generate variations'
isChecked={shouldGenerateVariations}
width={'auto'}
onChange={(e) =>
dispatch(setShouldGenerateVariations(e.target.checked))
}
/>
<SDNumberInput
label='Amount'
value={variantAmount}
step={0.01}
min={0}
max={1}
width={240}
isDisabled={!shouldGenerateVariations}
onChange={(v) => dispatch(setVariantAmount(Number(v)))}
/>
<FormControl
isInvalid={
shouldGenerateVariations &&
!(validateSeedWeights(seedWeights) || seedWeights === '')
}
flexGrow={1}
isDisabled={!shouldGenerateVariations}
>
<HStack>
<FormLabel marginInlineEnd={0} marginBottom={1}>
<Text fontSize={'sm'} whiteSpace='nowrap'>
Seed Weights
</Text>
</FormLabel>
<Input
size={'sm'}
value={seedWeights}
onChange={(e) =>
dispatch(setSeedWeights(e.target.value))
}
/>
</HStack>
</FormControl>
</Flex>
);
};
export default Variant;

View File

@ -1,56 +0,0 @@
export interface SeedWeightPair {
seed: number;
weight: number;
}
export type SeedWeights = Array<Array<number>>;
export const stringToSeedWeights = (string: string): SeedWeights | boolean => {
const stringPairs = string.split(',');
const arrPairs = stringPairs.map((p) => p.split(':'));
const pairs = arrPairs.map((p) => {
return [parseInt(p[0]), parseFloat(p[1])];
});
if (!validateSeedWeights(pairs)) {
return false;
}
return pairs;
};
export const validateSeedWeights = (
seedWeights: SeedWeights | string
): boolean => {
return typeof seedWeights === 'string'
? Boolean(stringToSeedWeights(seedWeights))
: Boolean(
seedWeights.length &&
!seedWeights.some((pair) => {
const [seed, weight] = pair;
const isSeedValid = !isNaN(parseInt(seed.toString(), 10));
const isWeightValid =
!isNaN(parseInt(weight.toString(), 10)) &&
weight >= 0 &&
weight <= 1;
return !(isSeedValid && isWeightValid);
})
);
};
export const seedWeightsToString = (
seedWeights: SeedWeights
): string | boolean => {
if (!validateSeedWeights(seedWeights)) {
return false;
}
return seedWeights.reduce((acc, pair, i, arr) => {
const [seed, weight] = pair;
acc += `${seed}:${weight}`;
if (i !== arr.length - 1) {
acc += ',';
}
return acc;
}, '');
};

View File

@ -1,11 +1,11 @@
import {
IconButton,
useColorModeValue,
Flex,
Text,
Tooltip,
IconButton,
useColorModeValue,
Flex,
Text,
Tooltip,
} from '@chakra-ui/react';
import { useAppDispatch, useAppSelector } from '../../app/hooks';
import { useAppDispatch, useAppSelector } from '../../app/store';
import { RootState } from '../../app/store';
import { setShouldShowLogViewer, SystemState } from './systemSlice';
import { useLayoutEffect, useRef, useState } from 'react';
@ -14,112 +14,138 @@ import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash';
const logSelector = createSelector(
(state: RootState) => state.system,
(system: SystemState) => system.log,
{
memoizeOptions: {
resultEqualityCheck: (a, b) => a.length === b.length,
},
}
(state: RootState) => state.system,
(system: SystemState) => system.log,
{
memoizeOptions: {
// We don't need a deep equality check for this selector.
resultEqualityCheck: (a, b) => a.length === b.length,
},
}
);
const systemSelector = createSelector(
(state: RootState) => state.system,
(system: SystemState) => {
return { shouldShowLogViewer: system.shouldShowLogViewer };
(state: RootState) => state.system,
(system: SystemState) => {
return { shouldShowLogViewer: system.shouldShowLogViewer };
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
}
);
/**
* Basic log viewer, floats on bottom of page.
*/
const LogViewer = () => {
const dispatch = useAppDispatch();
const bg = useColorModeValue('gray.50', 'gray.900');
const borderColor = useColorModeValue('gray.500', 'gray.500');
const [shouldAutoscroll, setShouldAutoscroll] = useState<boolean>(true);
const dispatch = useAppDispatch();
const log = useAppSelector(logSelector);
const { shouldShowLogViewer } = useAppSelector(systemSelector);
const log = useAppSelector(logSelector);
const { shouldShowLogViewer } = useAppSelector(systemSelector);
// Set colors based on dark/light mode
const bg = useColorModeValue('gray.50', 'gray.900');
const borderColor = useColorModeValue('gray.500', 'gray.500');
const logTextColors = useColorModeValue(
{
info: undefined,
warning: 'yellow.500',
error: 'red.500',
},
{
info: undefined,
warning: 'yellow.300',
error: 'red.300',
}
);
const viewerRef = useRef<HTMLDivElement>(null);
// Rudimentary autoscroll
const [shouldAutoscroll, setShouldAutoscroll] = useState<boolean>(true);
const viewerRef = useRef<HTMLDivElement>(null);
useLayoutEffect(() => {
if (viewerRef.current !== null && shouldAutoscroll) {
viewerRef.current.scrollTop = viewerRef.current.scrollHeight;
}
});
/**
* If autoscroll is on, scroll to the bottom when:
* - log updates
* - viewer is toggled
*
* Also scroll to the bottom whenever autoscroll is turned on.
*/
useLayoutEffect(() => {
if (viewerRef.current !== null && shouldAutoscroll) {
viewerRef.current.scrollTop = viewerRef.current.scrollHeight;
}
}, [shouldAutoscroll, log, shouldShowLogViewer]);
return (
<>
{shouldShowLogViewer && (
<Flex
position={'fixed'}
left={0}
bottom={0}
height='200px'
width='100vw'
overflow='auto'
direction='column'
fontFamily='monospace'
fontSize='sm'
pl={12}
pr={2}
pb={2}
borderTopWidth='4px'
borderColor={borderColor}
background={bg}
ref={viewerRef}
>
{log.map((entry, i) => (
<Flex gap={2} key={i}>
<Text fontSize='sm' fontWeight={'semibold'}>
{entry.timestamp}:
</Text>
<Text fontSize='sm' wordBreak={'break-all'}>
{entry.message}
</Text>
</Flex>
))}
</Flex>
)}
{shouldShowLogViewer && (
<Tooltip
label={
shouldAutoscroll ? 'Autoscroll on' : 'Autoscroll off'
}
>
<IconButton
size='sm'
position={'fixed'}
left={2}
bottom={12}
aria-label='Toggle autoscroll'
variant={'solid'}
colorScheme={shouldAutoscroll ? 'blue' : 'gray'}
icon={<FaAngleDoubleDown />}
onClick={() => setShouldAutoscroll(!shouldAutoscroll)}
/>
</Tooltip>
)}
<Tooltip label={shouldShowLogViewer ? 'Hide logs' : 'Show logs'}>
<IconButton
size='sm'
position={'fixed'}
left={2}
bottom={2}
variant={'solid'}
aria-label='Toggle Log Viewer'
icon={shouldShowLogViewer ? <FaMinus /> : <FaCode />}
onClick={() =>
dispatch(setShouldShowLogViewer(!shouldShowLogViewer))
}
/>
</Tooltip>
</>
);
const handleClickLogViewerToggle = () => {
dispatch(setShouldShowLogViewer(!shouldShowLogViewer));
};
return (
<>
{shouldShowLogViewer && (
<Flex
position={'fixed'}
left={0}
bottom={0}
height="200px" // TODO: Make the log viewer resizeable.
width="100vw"
overflow="auto"
direction="column"
fontFamily="monospace"
fontSize="sm"
pl={12}
pr={2}
pb={2}
borderTopWidth="4px"
borderColor={borderColor}
background={bg}
ref={viewerRef}
>
{log.map((entry, i) => {
const { timestamp, message, level } = entry;
return (
<Flex gap={2} key={i} textColor={logTextColors[level]}>
<Text fontSize="sm" fontWeight={'semibold'}>
{timestamp}:
</Text>
<Text fontSize="sm" wordBreak={'break-all'}>
{message}
</Text>
</Flex>
);
})}
</Flex>
)}
{shouldShowLogViewer && (
<Tooltip label={shouldAutoscroll ? 'Autoscroll on' : 'Autoscroll off'}>
<IconButton
size="sm"
position={'fixed'}
left={2}
bottom={12}
aria-label="Toggle autoscroll"
variant={'solid'}
colorScheme={shouldAutoscroll ? 'blue' : 'gray'}
icon={<FaAngleDoubleDown />}
onClick={() => setShouldAutoscroll(!shouldAutoscroll)}
/>
</Tooltip>
)}
<Tooltip label={shouldShowLogViewer ? 'Hide logs' : 'Show logs'}>
<IconButton
size="sm"
position={'fixed'}
left={2}
bottom={2}
variant={'solid'}
aria-label="Toggle Log Viewer"
icon={shouldShowLogViewer ? <FaMinus /> : <FaCode />}
onClick={handleClickLogViewerToggle}
/>
</Tooltip>
</>
);
};
export default LogViewer;

View File

@ -0,0 +1,38 @@
import { Progress } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash';
import { useAppSelector } from '../../app/store';
import { RootState } from '../../app/store';
import { SystemState } from '../system/systemSlice';
const systemSelector = createSelector(
(state: RootState) => state.system,
(system: SystemState) => {
return {
isProcessing: system.isProcessing,
currentStep: system.currentStep,
totalSteps: system.totalSteps,
currentStatusHasSteps: system.currentStatusHasSteps,
};
},
{
memoizeOptions: { resultEqualityCheck: isEqual },
}
);
const ProgressBar = () => {
const { isProcessing, currentStep, totalSteps, currentStatusHasSteps } =
useAppSelector(systemSelector);
const value = currentStep ? Math.round((currentStep * 100) / totalSteps) : 0;
return (
<Progress
height="10px"
value={value}
isIndeterminate={isProcessing && !currentStatusHasSteps}
/>
);
};
export default ProgressBar;

View File

@ -1,170 +1,164 @@
import {
Flex,
FormControl,
FormLabel,
Heading,
HStack,
Modal,
ModalBody,
ModalCloseButton,
ModalContent,
ModalFooter,
ModalHeader,
ModalOverlay,
Switch,
Text,
useDisclosure,
Button,
Flex,
FormControl,
FormLabel,
Heading,
HStack,
Modal,
ModalBody,
ModalCloseButton,
ModalContent,
ModalFooter,
ModalHeader,
ModalOverlay,
Switch,
Text,
useDisclosure,
} from '@chakra-ui/react';
import { useAppDispatch, useAppSelector } from '../../app/hooks';
import { useAppDispatch, useAppSelector } from '../../app/store';
import {
setShouldConfirmOnDelete,
setShouldDisplayInProgress,
SystemState,
setShouldConfirmOnDelete,
setShouldDisplayInProgress,
SystemState,
} from './systemSlice';
import { RootState } from '../../app/store';
import SDButton from '../../components/SDButton';
import { persistor } from '../../main';
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash';
import { cloneElement, ReactElement } from 'react';
const systemSelector = createSelector(
(state: RootState) => state.system,
(system: SystemState) => {
const { shouldDisplayInProgress, shouldConfirmOnDelete } = system;
return { shouldDisplayInProgress, shouldConfirmOnDelete };
},
{
memoizeOptions: { resultEqualityCheck: isEqual },
}
(state: RootState) => state.system,
(system: SystemState) => {
const { shouldDisplayInProgress, shouldConfirmOnDelete } = system;
return { shouldDisplayInProgress, shouldConfirmOnDelete };
},
{
memoizeOptions: { resultEqualityCheck: isEqual },
}
);
type Props = {
children: ReactElement;
type SettingsModalProps = {
/* The button to open the Settings Modal */
children: ReactElement;
};
const SettingsModal = ({ children }: Props) => {
const {
isOpen: isSettingsModalOpen,
onOpen: onSettingsModalOpen,
onClose: onSettingsModalClose,
} = useDisclosure();
/**
* Modal for app settings. Also provides Reset functionality in which the
* app's localstorage is wiped via redux-persist.
*
* Secondary post-reset modal is included here.
*/
const SettingsModal = ({ children }: SettingsModalProps) => {
const {
isOpen: isSettingsModalOpen,
onOpen: onSettingsModalOpen,
onClose: onSettingsModalClose,
} = useDisclosure();
const {
isOpen: isRefreshModalOpen,
onOpen: onRefreshModalOpen,
onClose: onRefreshModalClose,
} = useDisclosure();
const {
isOpen: isRefreshModalOpen,
onOpen: onRefreshModalOpen,
onClose: onRefreshModalClose,
} = useDisclosure();
const { shouldDisplayInProgress, shouldConfirmOnDelete } =
useAppSelector(systemSelector);
const { shouldDisplayInProgress, shouldConfirmOnDelete } =
useAppSelector(systemSelector);
const dispatch = useAppDispatch();
const dispatch = useAppDispatch();
const handleClickResetWebUI = () => {
persistor.purge().then(() => {
onSettingsModalClose();
onRefreshModalOpen();
});
};
/**
* Resets localstorage, then opens a secondary modal informing user to
* refresh their browser.
* */
const handleClickResetWebUI = () => {
persistor.purge().then(() => {
onSettingsModalClose();
onRefreshModalOpen();
});
};
return (
<>
{cloneElement(children, {
onClick: onSettingsModalOpen,
})}
return (
<>
{cloneElement(children, {
onClick: onSettingsModalOpen,
})}
<Modal isOpen={isSettingsModalOpen} onClose={onSettingsModalClose}>
<ModalOverlay />
<ModalContent>
<ModalHeader>Settings</ModalHeader>
<ModalCloseButton />
<ModalBody>
<Flex gap={5} direction='column'>
<FormControl>
<HStack>
<FormLabel marginBottom={1}>
Display in-progress images (slower)
</FormLabel>
<Switch
isChecked={shouldDisplayInProgress}
onChange={(e) =>
dispatch(
setShouldDisplayInProgress(
e.target.checked
)
)
}
/>
</HStack>
</FormControl>
<FormControl>
<HStack>
<FormLabel marginBottom={1}>
Confirm on delete
</FormLabel>
<Switch
isChecked={shouldConfirmOnDelete}
onChange={(e) =>
dispatch(
setShouldConfirmOnDelete(
e.target.checked
)
)
}
/>
</HStack>
</FormControl>
<Modal isOpen={isSettingsModalOpen} onClose={onSettingsModalClose}>
<ModalOverlay />
<ModalContent>
<ModalHeader>Settings</ModalHeader>
<ModalCloseButton />
<ModalBody>
<Flex gap={5} direction="column">
<FormControl>
<HStack>
<FormLabel marginBottom={1}>
Display in-progress images (slower)
</FormLabel>
<Switch
isChecked={shouldDisplayInProgress}
onChange={(e) =>
dispatch(setShouldDisplayInProgress(e.target.checked))
}
/>
</HStack>
</FormControl>
<FormControl>
<HStack>
<FormLabel marginBottom={1}>Confirm on delete</FormLabel>
<Switch
isChecked={shouldConfirmOnDelete}
onChange={(e) =>
dispatch(setShouldConfirmOnDelete(e.target.checked))
}
/>
</HStack>
</FormControl>
<Heading size={'md'}>Reset Web UI</Heading>
<Text>
Resetting the web UI only resets the browser's
local cache of your images and remembered
settings. It does not delete any images from
disk.
</Text>
<Text>
If images aren't showing up in the gallery or
something else isn't working, please try
resetting before submitting an issue on GitHub.
</Text>
<SDButton
label='Reset Web UI'
colorScheme='red'
onClick={handleClickResetWebUI}
/>
</Flex>
</ModalBody>
<Heading size={'md'}>Reset Web UI</Heading>
<Text>
Resetting the web UI only resets the browser's local cache of
your images and remembered settings. It does not delete any
images from disk.
</Text>
<Text>
If images aren't showing up in the gallery or something else
isn't working, please try resetting before submitting an issue
on GitHub.
</Text>
<Button colorScheme="red" onClick={handleClickResetWebUI}>
Reset Web UI
</Button>
</Flex>
</ModalBody>
<ModalFooter>
<SDButton
label='Close'
onClick={onSettingsModalClose}
/>
</ModalFooter>
</ModalContent>
</Modal>
<ModalFooter>
<Button onClick={onSettingsModalClose}>Close</Button>
</ModalFooter>
</ModalContent>
</Modal>
<Modal
closeOnOverlayClick={false}
isOpen={isRefreshModalOpen}
onClose={onRefreshModalClose}
isCentered
>
<ModalOverlay bg='blackAlpha.300' backdropFilter='blur(40px)' />
<ModalContent>
<ModalBody pb={6} pt={6}>
<Flex justifyContent={'center'}>
<Text fontSize={'lg'}>
Web UI has been reset. Refresh the page to
reload.
</Text>
</Flex>
</ModalBody>
</ModalContent>
</Modal>
</>
);
<Modal
closeOnOverlayClick={false}
isOpen={isRefreshModalOpen}
onClose={onRefreshModalClose}
isCentered
>
<ModalOverlay bg="blackAlpha.300" backdropFilter="blur(40px)" />
<ModalContent>
<ModalBody pb={6} pt={6}>
<Flex justifyContent={'center'}>
<Text fontSize={'lg'}>
Web UI has been reset. Refresh the page to reload.
</Text>
</Flex>
</ModalBody>
</ModalContent>
</Modal>
</>
);
};
export default SettingsModal;

View File

@ -0,0 +1,120 @@
import {
Flex,
Heading,
IconButton,
Link,
Spacer,
Text,
useColorMode,
} from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash';
import { FaSun, FaMoon, FaGithub } from 'react-icons/fa';
import { MdHelp, MdSettings } from 'react-icons/md';
import { useAppSelector } from '../../app/store';
import { RootState } from '../../app/store';
import SettingsModal from '../system/SettingsModal';
import { SystemState } from '../system/systemSlice';
const systemSelector = createSelector(
(state: RootState) => state.system,
(system: SystemState) => {
return {
isConnected: system.isConnected,
isProcessing: system.isProcessing,
currentIteration: system.currentIteration,
totalIterations: system.totalIterations,
currentStatus: system.currentStatus,
};
},
{
memoizeOptions: { resultEqualityCheck: isEqual },
}
);
/**
* Header, includes color mode toggle, settings button, status message.
*/
const SiteHeader = () => {
const { colorMode, toggleColorMode } = useColorMode();
const {
isConnected,
isProcessing,
currentIteration,
totalIterations,
currentStatus,
} = useAppSelector(systemSelector);
const statusMessageTextColor = isConnected ? 'green.500' : 'red.500';
const colorModeIcon = colorMode == 'light' ? <FaMoon /> : <FaSun />;
// Make FaMoon and FaSun icon apparent size consistent
const colorModeIconFontSize = colorMode == 'light' ? 18 : 20;
let statusMessage = currentStatus;
if (isProcessing) {
if (totalIterations > 1) {
statusMessage += ` [${currentIteration}/${totalIterations}]`;
}
}
return (
<Flex minWidth="max-content" alignItems="center" gap="1" pl={2} pr={1}>
<Heading size={'lg'}>InvokeUI</Heading>
<Spacer />
<Text textColor={statusMessageTextColor}>{statusMessage}</Text>
<SettingsModal>
<IconButton
aria-label="Settings"
variant="link"
fontSize={24}
size={'sm'}
icon={<MdSettings />}
/>
</SettingsModal>
<IconButton
aria-label="Link to Github Issues"
variant="link"
fontSize={23}
size={'sm'}
icon={
<Link
isExternal
href="http://github.com/lstein/stable-diffusion/issues"
>
<MdHelp />
</Link>
}
/>
<IconButton
aria-label="Link to Github Repo"
variant="link"
fontSize={20}
size={'sm'}
icon={
<Link isExternal href="http://github.com/lstein/stable-diffusion">
<FaGithub />
</Link>
}
/>
<IconButton
aria-label="Toggle Dark Mode"
onClick={toggleColorMode}
variant="link"
size={'sm'}
fontSize={colorModeIconFontSize}
icon={colorModeIcon}
/>
</Flex>
);
};
export default SiteHeader;

View File

@ -1,10 +1,13 @@
import { createSlice } from '@reduxjs/toolkit';
import type { PayloadAction } from '@reduxjs/toolkit';
import dateFormat from 'dateformat';
import { ExpandedIndex } from '@chakra-ui/react';
import * as InvokeAI from '../../app/invokeai'
export type LogLevel = 'info' | 'warning' | 'error';
export interface LogEntry {
timestamp: string;
level: LogLevel;
message: string;
}
@ -12,10 +15,8 @@ export interface Log {
[index: number]: LogEntry;
}
export interface SystemState {
export interface SystemState extends InvokeAI.SystemStatus, InvokeAI.SystemConfig {
shouldDisplayInProgress: boolean;
isProcessing: boolean;
currentStep: number;
log: Array<LogEntry>;
shouldShowLogViewer: boolean;
isGFPGANAvailable: boolean;
@ -24,12 +25,17 @@ export interface SystemState {
socketId: string;
shouldConfirmOnDelete: boolean;
openAccordions: ExpandedIndex;
currentStep: number;
totalSteps: number;
currentIteration: number;
totalIterations: number;
currentStatus: string;
currentStatusHasSteps: boolean;
}
const initialSystemState = {
isConnected: false,
isProcessing: false,
currentStep: 0,
log: [],
shouldShowLogViewer: false,
shouldDisplayInProgress: false,
@ -38,6 +44,17 @@ const initialSystemState = {
socketId: '',
shouldConfirmOnDelete: true,
openAccordions: [0],
currentStep: 0,
totalSteps: 0,
currentIteration: 0,
totalIterations: 0,
currentStatus: '',
currentStatusHasSteps: false,
model: '',
model_id: '',
model_hash: '',
app_id: '',
app_version: '',
};
const initialState: SystemState = initialSystemState;
@ -51,18 +68,35 @@ export const systemSlice = createSlice({
},
setIsProcessing: (state, action: PayloadAction<boolean>) => {
state.isProcessing = action.payload;
if (action.payload === false) {
state.currentStep = 0;
}
},
setCurrentStep: (state, action: PayloadAction<number>) => {
state.currentStep = action.payload;
setCurrentStatus: (state, action: PayloadAction<string>) => {
state.currentStatus = action.payload;
},
addLogEntry: (state, action: PayloadAction<string>) => {
setSystemStatus: (state, action: PayloadAction<InvokeAI.SystemStatus>) => {
const currentStatus =
!action.payload.isProcessing && state.isConnected
? 'Connected'
: action.payload.currentStatus;
return { ...state, ...action.payload, currentStatus };
},
addLogEntry: (
state,
action: PayloadAction<{
timestamp: string;
message: string;
level?: LogLevel;
}>
) => {
const { timestamp, message, level } = action.payload;
const logLevel = level || 'info';
const entry: LogEntry = {
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: action.payload,
timestamp,
message,
level: logLevel,
};
state.log.push(entry);
},
setShouldShowLogViewer: (state, action: PayloadAction<boolean>) => {
@ -80,19 +114,24 @@ export const systemSlice = createSlice({
setOpenAccordions: (state, action: PayloadAction<ExpandedIndex>) => {
state.openAccordions = action.payload;
},
setSystemConfig: (state, action: PayloadAction<InvokeAI.SystemConfig>) => {
return { ...state, ...action.payload };
},
},
});
export const {
setShouldDisplayInProgress,
setIsProcessing,
setCurrentStep,
addLogEntry,
setShouldShowLogViewer,
setIsConnected,
setSocketId,
setShouldConfirmOnDelete,
setOpenAccordions,
setSystemStatus,
setCurrentStatus,
setSystemConfig,
} = systemSlice.actions;
export default systemSlice.reducer;

View File

@ -1,108 +0,0 @@
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash';
import { useMemo } from 'react';
import { useAppSelector } from '../../app/hooks';
import { RootState } from '../../app/store';
import { SDState } from '../sd/sdSlice';
import { validateSeedWeights } from '../sd/util/seedWeightPairs';
import { SystemState } from './systemSlice';
const sdSelector = createSelector(
(state: RootState) => state.sd,
(sd: SDState) => {
return {
prompt: sd.prompt,
shouldGenerateVariations: sd.shouldGenerateVariations,
seedWeights: sd.seedWeights,
maskPath: sd.maskPath,
initialImagePath: sd.initialImagePath,
seed: sd.seed,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
const systemSelector = createSelector(
(state: RootState) => state.system,
(system: SystemState) => {
return {
isProcessing: system.isProcessing,
isConnected: system.isConnected,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
/*
Checks relevant pieces of state to confirm generation will not deterministically fail.
This is used to prevent the 'Generate' button from being clicked.
Other parameter values may cause failure but we rely on input validation for those.
*/
const useCheckParameters = () => {
const {
prompt,
shouldGenerateVariations,
seedWeights,
maskPath,
initialImagePath,
seed,
} = useAppSelector(sdSelector);
const { isProcessing, isConnected } = useAppSelector(systemSelector);
return useMemo(() => {
// Cannot generate without a prompt
if (!prompt) {
return false;
}
// Cannot generate with a mask without img2img
if (maskPath && !initialImagePath) {
return false;
}
// TODO: job queue
// Cannot generate if already processing an image
if (isProcessing) {
return false;
}
// Cannot generate if not connected
if (!isConnected) {
return false;
}
// Cannot generate variations without valid seed weights
if (
shouldGenerateVariations &&
(!(validateSeedWeights(seedWeights) || seedWeights === '') ||
seed === -1)
) {
return false;
}
// All good
return true;
}, [
prompt,
maskPath,
initialImagePath,
isProcessing,
isConnected,
shouldGenerateVariations,
seedWeights,
seed,
]);
};
export default useCheckParameters;

View File

@ -8,9 +8,9 @@ import { persistStore } from 'redux-persist';
export const persistor = persistStore(store);
import App from './App';
import { theme } from './app/theme';
import Loading from './Loading';
import App from './app/App';
ReactDOM.createRoot(document.getElementById('root') as HTMLElement).render(
<React.StrictMode>

View File

@ -419,6 +419,13 @@
compute-scroll-into-view "1.0.14"
copy-to-clipboard "3.3.1"
"@chakra-ui/icon@3.0.10":
version "3.0.10"
resolved "https://registry.yarnpkg.com/@chakra-ui/icon/-/icon-3.0.10.tgz#1a11b5edb42a8af7aa5b6dec2bf2c6c4df1869fc"
integrity sha512-utO569d9bptEraJrEhuImfNzQ8v+a8PsQh8kTsodCzg8B16R3t5TTuoqeJqS6Nq16Vq6w87QbX3/4A73CNK5fw==
dependencies:
"@chakra-ui/shared-utils" "2.0.1"
"@chakra-ui/icon@3.0.9":
version "3.0.9"
resolved "https://registry.yarnpkg.com/@chakra-ui/icon/-/icon-3.0.9.tgz#ba127d9eefd727f62e9bce07a23eca39ae506744"
@ -426,6 +433,13 @@
dependencies:
"@chakra-ui/shared-utils" "2.0.1"
"@chakra-ui/icons@^2.0.10":
version "2.0.10"
resolved "https://registry.yarnpkg.com/@chakra-ui/icons/-/icons-2.0.10.tgz#61aeb44c913c10e7ff77addc798494e50d66c760"
integrity sha512-hxMspvysOay2NsJyadM611F/Y4vVzJU/YkXTxsyBjm6v/DbENhpVmPnUf+kwwyl7dINNb9iOF+kuGxnuIEO1Tw==
dependencies:
"@chakra-ui/icon" "3.0.10"
"@chakra-ui/image@2.0.10":
version "2.0.10"
resolved "https://registry.yarnpkg.com/@chakra-ui/image/-/image-2.0.10.tgz#712c0e1c579d959225bd8316d8d8f66cbeb95bb8"

View File

@ -100,6 +100,13 @@ SAMPLER_CHOICES = [
'plms',
]
PRECISION_CHOICES = [
'auto',
'float32',
'autocast',
'float16',
]
# is there a way to pick this up during git commits?
APP_ID = 'lstein/stable-diffusion'
APP_VERSION = 'v1.15'
@ -174,31 +181,37 @@ class Args(object):
switches.append(f'-W {a["width"]}')
switches.append(f'-H {a["height"]}')
switches.append(f'-C {a["cfg_scale"]}')
switches.append(f'-A {a["sampler_name"]}')
if a['grid']:
switches.append('--grid')
if a['seamless']:
switches.append('--seamless')
# img2img generations have parameters relevant only to them and have special handling
if a['init_img'] and len(a['init_img'])>0:
switches.append(f'-I {a["init_img"]}')
if a['init_mask'] and len(a['init_mask'])>0:
switches.append(f'-M {a["init_mask"]}')
if a['init_color'] and len(a['init_color'])>0:
switches.append(f'--init_color {a["init_color"]}')
if a['fit']:
switches.append(f'--fit')
if a['init_img'] and a['strength'] and a['strength']>0:
switches.append(f'-f {a["strength"]}')
switches.append(f'-A ddim') # TODO: FIX ME WHEN IMG2IMG SUPPORTS ALL SAMPLERS
if a['fit']:
switches.append(f'--fit')
if a['init_mask'] and len(a['init_mask'])>0:
switches.append(f'-M {a["init_mask"]}')
if a['init_color'] and len(a['init_color'])>0:
switches.append(f'--init_color {a["init_color"]}')
if a['strength'] and a['strength']>0:
switches.append(f'-f {a["strength"]}')
else:
switches.append(f'-A {a["sampler_name"]}')
# gfpgan-specific parameters
if a['gfpgan_strength']:
switches.append(f'-G {a["gfpgan_strength"]}')
# esrgan-specific parameters
if a['upscale']:
switches.append(f'-U {" ".join([str(u) for u in a["upscale"]])}')
if a['embiggen']:
switches.append(f'--embiggen {" ".join([str(u) for u in a["embiggen"]])}')
if a['embiggen_tiles']:
switches.append(f'--embiggen_tiles {" ".join([str(u) for u in a["embiggen_tiles"]])}')
if a['variation_amount'] > 0:
switches.append(f'-v {a["variation_amount"]}')
if a['with_variations']:
formatted_variations = ','.join(f'{seed}:{weight}' for seed, weight in (a["with_variations"]))
switches.append(f'-V {formatted_variations}')
@ -316,7 +329,16 @@ class Args(object):
'--full_precision',
dest='full_precision',
action='store_true',
help='Use more memory-intensive full precision math for calculations',
help='Deprecated way to set --precision=float32',
)
model_group.add_argument(
'--precision',
dest='precision',
type=str,
choices=PRECISION_CHOICES,
metavar='PRECISION',
help=f'Set model precision. Defaults to auto selected based on device. Options: {", ".join(PRECISION_CHOICES)}',
default='auto',
)
file_group.add_argument(
'--from_file',
@ -618,18 +640,24 @@ def metadata_dumps(opt,
postprocessing=postprocessing
)
# TODO: This is just a hack until postprocessing pipeline work completed
image_dict['postprocessing'] = []
if image_dict['gfpgan_strength'] and image_dict['gfpgan_strength'] > 0:
image_dict['postprocessing'].append('GFPGAN (not RFC compliant)')
if image_dict['upscale'] and image_dict['upscale'][0] > 0:
image_dict['postprocessing'].append('ESRGAN (not RFC compliant)')
# 'postprocessing' is either null or an array of postprocessing metadatal
if postprocessing:
# TODO: This is just a hack until postprocessing pipeline work completed
image_dict['postprocessing'] = []
if image_dict['gfpgan_strength'] and image_dict['gfpgan_strength'] > 0:
image_dict['postprocessing'].append('GFPGAN (not RFC compliant)')
if image_dict['upscale'] and image_dict['upscale'][0] > 0:
image_dict['postprocessing'].append('ESRGAN (not RFC compliant)')
else:
image_dict['postprocessing'] = None
# remove any image keys not mentioned in RFC #266
rfc266_img_fields = ['type','postprocessing','sampler','prompt','seed','variations','steps',
'cfg_scale','step_number','width','height','extra','strength']
rfc_dict ={}
for item in image_dict.items():
key,value = item
if key in rfc266_img_fields:
@ -637,25 +665,24 @@ def metadata_dumps(opt,
# semantic drift
rfc_dict['sampler'] = image_dict.get('sampler_name',None)
# display weighted subprompts (liable to change)
if opt.prompt:
subprompts = split_weighted_subprompts(opt.prompt)
subprompts = [{'prompt':x[0],'weight':x[1]} for x in subprompts]
rfc_dict['prompt'] = subprompts
# variations
if opt.with_variations:
variations = [{'seed':x[0],'weight':x[1]} for x in opt.with_variations]
rfc_dict['variations'] = variations
# 'variations' should always exist and be an array, empty or consisting of {'seed': seed, 'weight': weight} pairs
rfc_dict['variations'] = [{'seed':x[0],'weight':x[1]} for x in opt.with_variations] if opt.with_variations else []
if opt.init_img:
rfc_dict['type'] = 'img2img'
rfc_dict['strength_steps'] = rfc_dict.pop('strength')
rfc_dict['orig_hash'] = calculate_init_img_hash(opt.init_img)
rfc_dict['sampler'] = 'ddim' # FIX ME WHEN IMG2IMG SUPPORTS ALL SAMPLERS
rfc_dict['sampler'] = 'ddim' # TODO: FIX ME WHEN IMG2IMG SUPPORTS ALL SAMPLERS
else:
rfc_dict['type'] = 'txt2img'
rfc_dict.pop('strength')
if len(seeds)==0 and opt.seed:
seeds=[seed]

View File

@ -1,6 +1,6 @@
import torch
from torch import autocast
from contextlib import contextmanager, nullcontext
from contextlib import nullcontext
def choose_torch_device() -> str:
'''Convenience routine for guessing which GPU device to run model on'''
@ -10,15 +10,18 @@ def choose_torch_device() -> str:
return 'mps'
return 'cpu'
def choose_autocast_device(device):
'''Returns an autocast compatible device from a torch device'''
device_type = device.type # this returns 'mps' on M1
# autocast only for cuda, but GTX 16xx have issues with it
if device_type == 'cuda':
device_name = torch.cuda.get_device_name()
if 'GeForce GTX 1660' in device_name or 'GeForce GTX 1650' in device_name:
return device_type,nullcontext
else:
return device_type,autocast
else:
return 'cpu',nullcontext
def choose_precision(device) -> str:
'''Returns an appropriate precision for the given torch device'''
if device.type == 'cuda':
device_name = torch.cuda.get_device_name(device)
if not ('GeForce GTX 1660' in device_name or 'GeForce GTX 1650' in device_name):
return 'float16'
return 'float32'
def choose_autocast(precision):
'''Returns an autocast context or nullcontext for the given precision string'''
# float16 currently requires autocast to avoid errors like:
# 'expected scalar type Half but found Float'
if precision == 'autocast' or precision == 'float16':
return autocast
return nullcontext

View File

@ -9,13 +9,14 @@ from tqdm import tqdm, trange
from PIL import Image
from einops import rearrange, repeat
from pytorch_lightning import seed_everything
from ldm.dream.devices import choose_autocast_device
from ldm.dream.devices import choose_autocast
downsampling = 8
class Generator():
def __init__(self,model):
def __init__(self, model, precision):
self.model = model
self.precision = precision
self.seed = None
self.latent_channels = model.channels
self.downsampling_factor = downsampling # BUG: should come from model or config
@ -38,7 +39,7 @@ class Generator():
def generate(self,prompt,init_image,width,height,iterations=1,seed=None,
image_callback=None, step_callback=None,
**kwargs):
device_type,scope = choose_autocast_device(self.model.device)
scope = choose_autocast(self.precision)
make_image = self.get_make_image(
prompt,
init_image = init_image,
@ -51,7 +52,7 @@ class Generator():
results = []
seed = seed if seed else self.new_seed()
seed, initial_noise = self.generate_initial_noise(seed, width, height)
with scope(device_type), self.model.ema_scope():
with scope(self.model.device.type), self.model.ema_scope():
for n in trange(iterations, desc='Generating'):
x_T = None
if self.variation_amount > 0:

View File

@ -11,8 +11,8 @@ from ldm.models.diffusion.ddim import DDIMSampler
from ldm.dream.generator.img2img import Img2Img
class Embiggen(Generator):
def __init__(self,model):
super().__init__(model)
def __init__(self, model, precision):
super().__init__(model, precision)
self.init_latent = None
@torch.no_grad()

View File

@ -4,15 +4,15 @@ ldm.dream.generator.img2img descends from ldm.dream.generator
import torch
import numpy as np
from ldm.dream.devices import choose_autocast_device
from ldm.dream.devices import choose_autocast
from ldm.dream.generator.base import Generator
from ldm.models.diffusion.ddim import DDIMSampler
class Img2Img(Generator):
def __init__(self,model):
super().__init__(model)
def __init__(self, model, precision):
super().__init__(model, precision)
self.init_latent = None # by get_noise()
@torch.no_grad()
def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
conditioning,init_image,strength,step_callback=None,**kwargs):
@ -32,8 +32,8 @@ class Img2Img(Generator):
ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
)
device_type,scope = choose_autocast_device(self.model.device)
with scope(device_type):
scope = choose_autocast(self.precision)
with scope(self.model.device.type):
self.init_latent = self.model.get_first_stage_encoding(
self.model.encode_first_stage(init_image)
) # move to latent space

View File

@ -5,15 +5,15 @@ ldm.dream.generator.inpaint descends from ldm.dream.generator
import torch
import numpy as np
from einops import rearrange, repeat
from ldm.dream.devices import choose_autocast_device
from ldm.dream.devices import choose_autocast
from ldm.dream.generator.img2img import Img2Img
from ldm.models.diffusion.ddim import DDIMSampler
class Inpaint(Img2Img):
def __init__(self,model):
def __init__(self, model, precision):
self.init_latent = None
super().__init__(model)
super().__init__(model, precision)
@torch.no_grad()
def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
conditioning,init_image,mask_image,strength,
@ -38,8 +38,8 @@ class Inpaint(Img2Img):
ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
)
device_type,scope = choose_autocast_device(self.model.device)
with scope(device_type):
scope = choose_autocast(self.precision)
with scope(self.model.device.type):
self.init_latent = self.model.get_first_stage_encoding(
self.model.encode_first_stage(init_image)
) # move to latent space

View File

@ -7,9 +7,9 @@ import numpy as np
from ldm.dream.generator.base import Generator
class Txt2Img(Generator):
def __init__(self,model):
super().__init__(model)
def __init__(self, model, precision):
super().__init__(model, precision)
@torch.no_grad()
def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
conditioning,width,height,step_callback=None,**kwargs):

61
ldm/dream/log.py Normal file
View File

@ -0,0 +1,61 @@
"""
Functions for better format logging
write_log -- logs the name of the output image, prompt, and prompt args to the terminal and different types of file
1 write_log_message -- Writes a message to the console
2 write_log_files -- Writes a message to files
2.1 write_log_default -- File in plain text
2.2 write_log_txt -- File in txt format
2.3 write_log_markdown -- File in markdown format
"""
import os
def write_log(results, log_path, file_types, output_cntr):
"""
logs the name of the output image, prompt, and prompt args to the terminal and files
"""
output_cntr = write_log_message(results, output_cntr)
write_log_files(results, log_path, file_types)
return output_cntr
def write_log_message(results, output_cntr):
"""logs to the terminal"""
log_lines = [f"{path}: {prompt}\n" for path, prompt in results]
for l in log_lines:
output_cntr += 1
print(f"[{output_cntr}] {l}", end="")
return output_cntr
def write_log_files(results, log_path, file_types):
for file_type in file_types:
if file_type == "txt":
write_log_txt(log_path, results)
elif file_type == "md" or file_type == "markdown":
write_log_markdown(log_path, results)
else:
print(f"'{file_type}' format is not supported, so write in plain text")
write_log_default(log_path, results, file_type)
def write_log_default(log_path, results, file_type):
plain_txt_lines = [f"{path}: {prompt}\n" for path, prompt in results]
with open(log_path + "." + file_type, "a", encoding="utf-8") as file:
file.writelines(plain_txt_lines)
def write_log_txt(log_path, results):
txt_lines = [f"{path}: {prompt}\n" for path, prompt in results]
with open(log_path + ".txt", "a", encoding="utf-8") as file:
file.writelines(txt_lines)
def write_log_markdown(log_path, results):
md_lines = []
for path, prompt in results:
file_name = os.path.basename(path)
md_lines.append(f"## {file_name}\n![]({file_name})\n\n{prompt}\n")
with open(log_path + ".md", "a", encoding="utf-8") as file:
file.writelines(md_lines)

View File

@ -34,6 +34,7 @@ class PngWriter:
# saves image named _image_ to outdir/name, writing metadata from prompt
# returns full path of output
def save_image_and_prompt_to_png(self, image, dream_prompt, name, metadata=None):
print(f'self.outdir={self.outdir}, name={name}')
path = os.path.join(self.outdir, name)
info = PngImagePlugin.PngInfo()
info.add_text('Dream', dream_prompt)

View File

@ -29,7 +29,7 @@ from ldm.models.diffusion.plms import PLMSSampler
from ldm.models.diffusion.ksampler import KSampler
from ldm.dream.pngwriter import PngWriter
from ldm.dream.image_util import InitImageResizer
from ldm.dream.devices import choose_torch_device
from ldm.dream.devices import choose_torch_device, choose_precision
from ldm.dream.conditioning import get_uc_and_c
def fix_func(orig):
@ -104,7 +104,7 @@ gr = Generate(
# these values are set once and shouldn't be changed
conf = path to configuration file ('configs/models.yaml')
model = symbolic name of the model in the configuration file
full_precision = False
precision = float precision to be used
# this value is sticky and maintained between generation calls
sampler_name = ['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms'] // k_lms
@ -130,6 +130,7 @@ class Generate:
sampler_name = 'k_lms',
ddim_eta = 0.0, # deterministic
full_precision = False,
precision = 'auto',
# these are deprecated; if present they override values in the conf file
weights = None,
config = None,
@ -145,7 +146,7 @@ class Generate:
self.cfg_scale = 7.5
self.sampler_name = sampler_name
self.ddim_eta = 0.0 # same seed always produces same image
self.full_precision = True if choose_torch_device() == 'mps' else full_precision
self.precision = precision
self.strength = 0.75
self.seamless = False
self.embedding_path = embedding_path
@ -162,6 +163,14 @@ class Generate:
# it wasn't actually doing anything. This logic could be reinstated.
device_type = choose_torch_device()
self.device = torch.device(device_type)
if full_precision:
if self.precision != 'auto':
raise ValueError('Remove --full_precision / -F if using --precision')
print('Please remove deprecated --full_precision / -F')
print('If auto config does not work you can use --precision=float32')
self.precision = 'float32'
if self.precision == 'auto':
self.precision = choose_precision(self.device)
# for VRAM usage statistics
self.session_peakmem = torch.cuda.max_memory_allocated() if self._has_cuda else None
@ -440,25 +449,25 @@ class Generate:
def _make_img2img(self):
if not self.generators.get('img2img'):
from ldm.dream.generator.img2img import Img2Img
self.generators['img2img'] = Img2Img(self.model)
self.generators['img2img'] = Img2Img(self.model, self.precision)
return self.generators['img2img']
def _make_embiggen(self):
if not self.generators.get('embiggen'):
from ldm.dream.generator.embiggen import Embiggen
self.generators['embiggen'] = Embiggen(self.model)
self.generators['embiggen'] = Embiggen(self.model, self.precision)
return self.generators['embiggen']
def _make_txt2img(self):
if not self.generators.get('txt2img'):
from ldm.dream.generator.txt2img import Txt2Img
self.generators['txt2img'] = Txt2Img(self.model)
self.generators['txt2img'] = Txt2Img(self.model, self.precision)
return self.generators['txt2img']
def _make_inpaint(self):
if not self.generators.get('inpaint'):
from ldm.dream.generator.inpaint import Inpaint
self.generators['inpaint'] = Inpaint(self.model)
self.generators['inpaint'] = Inpaint(self.model, self.precision)
return self.generators['inpaint']
def load_model(self):
@ -469,7 +478,7 @@ class Generate:
model = self._load_model_from_config(self.config, self.weights)
if self.embedding_path is not None:
model.embedding_manager.load(
self.embedding_path, self.full_precision
self.embedding_path, self.precision == 'float32' or self.precision == 'autocast'
)
self.model = model.to(self.device)
# model.to doesn't change the cond_stage_model.device used to move the tokenizer output, so set it here
@ -619,16 +628,13 @@ class Generate:
sd = pl_sd['state_dict']
model = instantiate_from_config(c.model)
m, u = model.load_state_dict(sd, strict=False)
if self.full_precision:
print(
'>> Using slower but more accurate full-precision math (--full_precision)'
)
if self.precision == 'float16':
print('Using faster float16 precision')
model.to(torch.float16)
else:
print(
'>> Using half precision math. Call with --full_precision to use more accurate but VRAM-intensive full precision.'
)
model.half()
print('Using more accurate float32 precision')
model.to(self.device)
model.eval()

3
pyproject.toml Normal file
View File

@ -0,0 +1,3 @@
[tool.blue]
line-length = 90
target-version = ['py310']

Some files were not shown because too many files have changed in this diff Show More