Merge branch 'development' of github.com:psychedelicious/stable-diffusion into psychedelicious-development

This commit is contained in:
Lincoln Stein 2022-09-20 14:54:09 -04:00
commit 7830fd8ca1
44 changed files with 2293 additions and 2020 deletions

View File

@ -2,14 +2,14 @@ from modules.parse_seed_weights import parse_seed_weights
import argparse import argparse
SAMPLER_CHOICES = [ SAMPLER_CHOICES = [
'ddim', "ddim",
'k_dpm_2_a', "k_dpm_2_a",
'k_dpm_2', "k_dpm_2",
'k_euler_a', "k_euler_a",
'k_euler', "k_euler",
'k_heun', "k_heun",
'k_lms', "k_lms",
'plms', "plms",
] ]
@ -20,194 +20,42 @@ def parameters_to_command(params):
switches = list() switches = list()
if 'prompt' in params: if "prompt" in params:
switches.append(f'"{params["prompt"]}"') switches.append(f'"{params["prompt"]}"')
if 'steps' in params: if "steps" in params:
switches.append(f'-s {params["steps"]}') switches.append(f'-s {params["steps"]}')
if 'seed' in params: if "seed" in params:
switches.append(f'-S {params["seed"]}') switches.append(f'-S {params["seed"]}')
if 'width' in params: if "width" in params:
switches.append(f'-W {params["width"]}') switches.append(f'-W {params["width"]}')
if 'height' in params: if "height" in params:
switches.append(f'-H {params["height"]}') switches.append(f'-H {params["height"]}')
if 'cfg_scale' in params: if "cfg_scale" in params:
switches.append(f'-C {params["cfg_scale"]}') switches.append(f'-C {params["cfg_scale"]}')
if 'sampler_name' in params: if "sampler_name" in params:
switches.append(f'-A {params["sampler_name"]}') switches.append(f'-A {params["sampler_name"]}')
if 'seamless' in params and params["seamless"] == True: if "seamless" in params and params["seamless"] == True:
switches.append(f'--seamless') switches.append(f"--seamless")
if 'init_img' in params and len(params['init_img']) > 0: if "init_img" in params and len(params["init_img"]) > 0:
switches.append(f'-I {params["init_img"]}') switches.append(f'-I {params["init_img"]}')
if 'init_mask' in params and len(params['init_mask']) > 0: if "init_mask" in params and len(params["init_mask"]) > 0:
switches.append(f'-M {params["init_mask"]}') switches.append(f'-M {params["init_mask"]}')
if 'init_color' in params and len(params['init_color']) > 0: if "init_color" in params and len(params["init_color"]) > 0:
switches.append(f'--init_color {params["init_color"]}') switches.append(f'--init_color {params["init_color"]}')
if 'strength' in params and 'init_img' in params: if "strength" in params and "init_img" in params:
switches.append(f'-f {params["strength"]}') switches.append(f'-f {params["strength"]}')
if 'fit' in params and params["fit"] == True: if "fit" in params and params["fit"] == True:
switches.append(f'--fit') switches.append(f"--fit")
if 'gfpgan_strength' in params and params["gfpgan_strength"]: if "gfpgan_strength" in params and params["gfpgan_strength"]:
switches.append(f'-G {params["gfpgan_strength"]}') switches.append(f'-G {params["gfpgan_strength"]}')
if 'upscale' in params and params["upscale"]: if "upscale" in params and params["upscale"]:
switches.append(f'-U {params["upscale"][0]} {params["upscale"][1]}') switches.append(f'-U {params["upscale"][0]} {params["upscale"][1]}')
if 'variation_amount' in params and params['variation_amount'] > 0: if "variation_amount" in params and params["variation_amount"] > 0:
switches.append(f'-v {params["variation_amount"]}') switches.append(f'-v {params["variation_amount"]}')
if 'with_variations' in params: if "with_variations" in params:
seed_weight_pairs = ','.join(f'{seed}:{weight}' for seed, weight in params["with_variations"]) seed_weight_pairs = ",".join(
switches.append(f'-V {seed_weight_pairs}') f"{seed}:{weight}" for seed, weight in params["with_variations"]
)
switches.append(f"-V {seed_weight_pairs}")
return ' '.join(switches) return " ".join(switches)
def create_cmd_parser():
"""
This is simply a copy of the parser from `dream.py` with a change to give
prompt a default value. This is a temporary hack pending merge of #587 which
provides a better way to do this.
"""
parser = argparse.ArgumentParser(
description='Example: dream> a fantastic alien landscape -W1024 -H960 -s100 -n12',
exit_on_error=True,
)
parser.add_argument('prompt', nargs='?', default='')
parser.add_argument('-s', '--steps', type=int, help='Number of steps')
parser.add_argument(
'-S',
'--seed',
type=int,
help='Image seed; a +ve integer, or use -1 for the previous seed, -2 for the one before that, etc',
)
parser.add_argument(
'-n',
'--iterations',
type=int,
default=1,
help='Number of samplings to perform (slower, but will provide seeds for individual images)',
)
parser.add_argument(
'-W', '--width', type=int, help='Image width, multiple of 64'
)
parser.add_argument(
'-H', '--height', type=int, help='Image height, multiple of 64'
)
parser.add_argument(
'-C',
'--cfg_scale',
default=7.5,
type=float,
help='Classifier free guidance (CFG) scale - higher numbers cause generator to "try" harder.',
)
parser.add_argument(
'-g', '--grid', action='store_true', help='generate a grid'
)
parser.add_argument(
'--outdir',
'-o',
type=str,
default=None,
help='Directory to save generated images and a log of prompts and seeds',
)
parser.add_argument(
'--seamless',
action='store_true',
help='Change the model to seamless tiling (circular) mode',
)
parser.add_argument(
'-i',
'--individual',
action='store_true',
help='Generate individual files (default)',
)
parser.add_argument(
'-I',
'--init_img',
type=str,
help='Path to input image for img2img mode (supersedes width and height)',
)
parser.add_argument(
'-M',
'--init_mask',
type=str,
help='Path to input mask for inpainting mode (supersedes width and height)',
)
parser.add_argument(
'--init_color',
type=str,
help='Path to reference image for color correction (used for repeated img2img and inpainting)'
)
parser.add_argument(
'-T',
'-fit',
'--fit',
action='store_true',
help='If specified, will resize the input image to fit within the dimensions of width x height (512x512 default)',
)
parser.add_argument(
'-f',
'--strength',
default=0.75,
type=float,
help='Strength for noising/unnoising. 0.0 preserves image exactly, 1.0 replaces it completely',
)
parser.add_argument(
'-G',
'--gfpgan_strength',
default=0,
type=float,
help='The strength at which to apply the GFPGAN model to the result, in order to improve faces.',
)
parser.add_argument(
'-U',
'--upscale',
nargs='+',
default=None,
type=float,
help='Scale factor (2, 4) for upscaling followed by upscaling strength (0-1.0). If strength not specified, defaults to 0.75'
)
parser.add_argument(
'-save_orig',
'--save_original',
action='store_true',
help='Save original. Use it when upscaling to save both versions.',
)
# variants is going to be superseded by a generalized "prompt-morph" function
# parser.add_argument('-v','--variants',type=int,help="in img2img mode, the first generated image will get passed back to img2img to generate the requested number of variants")
parser.add_argument(
'-x',
'--skip_normalize',
action='store_true',
help='Skip subprompt weight normalization',
)
parser.add_argument(
'-A',
'-m',
'--sampler',
dest='sampler_name',
default=None,
type=str,
choices=SAMPLER_CHOICES,
metavar='SAMPLER_NAME',
help=f'Switch to a different sampler. Supported samplers: {", ".join(SAMPLER_CHOICES)}',
)
parser.add_argument(
'-t',
'--log_tokenization',
action='store_true',
help='shows how the prompt is split into tokens'
)
parser.add_argument(
'-v',
'--variation_amount',
default=0.0,
type=float,
help='If > 0, generates variations on the initial seed instead of random seeds per iteration. Must be between 0 and 1. Higher values will be more different.'
)
parser.add_argument(
'-V',
'--with_variations',
default=None,
type=str,
help='list of variations to apply, in the format `seed:weight,seed:weight,...'
)
return parser

View File

@ -6,7 +6,6 @@ import traceback
import eventlet import eventlet
import glob import glob
import shlex import shlex
import argparse
import math import math
import shutil import shutil
@ -23,8 +22,10 @@ from ldm.gfpgan.gfpgan_tools import real_esrgan_upscale
from ldm.gfpgan.gfpgan_tools import run_gfpgan from ldm.gfpgan.gfpgan_tools import run_gfpgan
from ldm.generate import Generate from ldm.generate import Generate
from ldm.dream.pngwriter import PngWriter, retrieve_metadata from ldm.dream.pngwriter import PngWriter, retrieve_metadata
from ldm.dream.args import APP_ID, APP_VERSION, calculate_init_img_hash
from ldm.dream.conditioning import split_weighted_subprompts
from modules.parameters import parameters_to_command, create_cmd_parser from modules.parameters import parameters_to_command
""" """
@ -32,12 +33,14 @@ USER CONFIG
""" """
output_dir = "outputs/" # Base output directory for images output_dir = "outputs/" # Base output directory for images
#host = 'localhost' # Web & socket.io host # host = 'localhost' # Web & socket.io host
host = '0.0.0.0' # Web & socket.io host host = "localhost" # Web & socket.io host
port = 9090 # Web & socket.io port port = 9090 # Web & socket.io port
verbose = False # enables copious socket.io logging verbose = False # enables copious socket.io logging
additional_allowed_origins = ['http://localhost:9090'] # additional CORS allowed origins additional_allowed_origins = [
"http://localhost:5173"
] # additional CORS allowed origins
model = "stable-diffusion-1.4"
""" """
END USER CONFIG END USER CONFIG
@ -50,26 +53,23 @@ SERVER SETUP
# fix missing mimetypes on windows due to registry wonkiness # fix missing mimetypes on windows due to registry wonkiness
mimetypes.add_type('application/javascript', '.js') mimetypes.add_type("application/javascript", ".js")
mimetypes.add_type('text/css', '.css') mimetypes.add_type("text/css", ".css")
app = Flask(__name__, static_url_path='', static_folder='../frontend/dist/') app = Flask(__name__, static_url_path="", static_folder="../frontend/dist/")
app.config['OUTPUTS_FOLDER'] = "../outputs" app.config["OUTPUTS_FOLDER"] = "../outputs"
@app.route('/outputs/<path:filename>') @app.route("/outputs/<path:filename>")
def outputs(filename): def outputs(filename):
return send_from_directory( return send_from_directory(app.config["OUTPUTS_FOLDER"], filename)
app.config['OUTPUTS_FOLDER'],
filename
)
@app.route("/", defaults={'path': ''}) @app.route("/", defaults={"path": ""})
def serve(path): def serve(path):
return send_from_directory(app.static_folder, 'index.html') return send_from_directory(app.static_folder, "index.html")
logger = True if verbose else False logger = True if verbose else False
@ -81,12 +81,12 @@ max_http_buffer_size = 10000000
cors_allowed_origins = [f"http://{host}:{port}"] + additional_allowed_origins cors_allowed_origins = [f"http://{host}:{port}"] + additional_allowed_origins
socketio = SocketIO( socketio = SocketIO(
app, app,
logger=logger, logger=logger,
engineio_logger=engineio_logger, engineio_logger=engineio_logger,
max_http_buffer_size=max_http_buffer_size, max_http_buffer_size=max_http_buffer_size,
cors_allowed_origins=cors_allowed_origins, cors_allowed_origins=cors_allowed_origins,
) )
""" """
@ -107,29 +107,31 @@ canceled = Event()
# reduce logging outputs to error # reduce logging outputs to error
transformers.logging.set_verbosity_error() transformers.logging.set_verbosity_error()
logging.getLogger('pytorch_lightning').setLevel(logging.ERROR) logging.getLogger("pytorch_lightning").setLevel(logging.ERROR)
# Initialize and load model # Initialize and load model
model = Generate() generate = Generate(model)
model.load_model() generate.load_model()
# location for "finished" images # location for "finished" images
result_path = os.path.join(output_dir, 'img-samples/') result_path = os.path.join(output_dir, "img-samples/")
# temporary path for intermediates # temporary path for intermediates
intermediate_path = os.path.join(result_path, 'intermediates/') intermediate_path = os.path.join(result_path, "intermediates/")
# path for user-uploaded init images and masks # path for user-uploaded init images and masks
init_image_path = os.path.join(result_path, 'init-images/') init_image_path = os.path.join(result_path, "init-images/")
mask_image_path = os.path.join(result_path, 'mask-images/') mask_image_path = os.path.join(result_path, "mask-images/")
# txt log # txt log
log_path = os.path.join(result_path, 'dream_log.txt') log_path = os.path.join(result_path, "dream_log.txt")
# make all output paths # make all output paths
[os.makedirs(path, exist_ok=True) [
for path in [result_path, intermediate_path, init_image_path, mask_image_path]] os.makedirs(path, exist_ok=True)
for path in [result_path, intermediate_path, init_image_path, mask_image_path]
]
""" """
@ -142,186 +144,219 @@ SOCKET.IO LISTENERS
""" """
@socketio.on('requestAllImages') @socketio.on("requestSystemConfig")
def handle_request_capabilities():
print(f">> System config requested")
config = get_system_config()
socketio.emit("systemConfig", config)
@socketio.on("requestAllImages")
def handle_request_all_images(): def handle_request_all_images():
print(f'>> All images requested') print(f">> All images requested")
parser = create_cmd_parser()
paths = list(filter(os.path.isfile, glob.glob(result_path + "*.png"))) paths = list(filter(os.path.isfile, glob.glob(result_path + "*.png")))
paths.sort(key=lambda x: os.path.getmtime(x)) paths.sort(key=lambda x: os.path.getmtime(x))
image_array = [] image_array = []
for path in paths: for path in paths:
# image = Image.open(path) metadata = retrieve_metadata(path)
all_metadata = retrieve_metadata(path) image_array.append({"url": path, "metadata": metadata["sd-metadata"]})
if 'Dream' in all_metadata and not all_metadata['sd-metadata']: socketio.emit("galleryImages", {"images": image_array})
metadata = vars(parser.parse_args(shlex.split(all_metadata['Dream'])))
else:
metadata = all_metadata['sd-metadata']
image_array.append({'path': path, 'metadata': metadata})
socketio.emit('galleryImages', {'images': image_array})
eventlet.sleep(0) eventlet.sleep(0)
@socketio.on('generateImage') @socketio.on("generateImage")
def handle_generate_image_event(generation_parameters, esrgan_parameters, gfpgan_parameters): def handle_generate_image_event(
print(f'>> Image generation requested: {generation_parameters}\nESRGAN parameters: {esrgan_parameters}\nGFPGAN parameters: {gfpgan_parameters}') generation_parameters, esrgan_parameters, gfpgan_parameters
generate_images( ):
generation_parameters, print(
esrgan_parameters, f">> Image generation requested: {generation_parameters}\nESRGAN parameters: {esrgan_parameters}\nGFPGAN parameters: {gfpgan_parameters}"
gfpgan_parameters
) )
generate_images(generation_parameters, esrgan_parameters, gfpgan_parameters)
@socketio.on('runESRGAN') @socketio.on("runESRGAN")
def handle_run_esrgan_event(original_image, esrgan_parameters): def handle_run_esrgan_event(original_image, esrgan_parameters):
print(f'>> ESRGAN upscale requested for "{original_image["url"]}": {esrgan_parameters}') print(
f'>> ESRGAN upscale requested for "{original_image["url"]}": {esrgan_parameters}'
)
progress = { progress = {
'currentStep': 1, "currentStep": 1,
'totalSteps': 1, "totalSteps": 1,
'currentIteration': 1, "currentIteration": 1,
'totalIterations': 1, "totalIterations": 1,
'currentStatus': 'Preparing', "currentStatus": "Preparing",
'isProcessing': True, "isProcessing": True,
'currentStatusHasSteps': False "currentStatusHasSteps": False,
} }
socketio.emit('progressUpdate', progress) socketio.emit("progressUpdate", progress)
eventlet.sleep(0) eventlet.sleep(0)
image = Image.open(original_image["url"]) image = Image.open(original_image["url"])
seed = original_image['metadata']['seed'] if 'seed' in original_image['metadata'] else 'unknown_seed' seed = (
original_image["metadata"]["seed"]
if "seed" in original_image["metadata"]
else "unknown_seed"
)
progress['currentStatus'] = 'Upscaling' progress["currentStatus"] = "Upscaling"
socketio.emit('progressUpdate', progress) socketio.emit("progressUpdate", progress)
eventlet.sleep(0) eventlet.sleep(0)
image = real_esrgan_upscale( image = real_esrgan_upscale(
image=image, image=image,
upsampler_scale=esrgan_parameters['upscale'][0], upsampler_scale=esrgan_parameters["upscale"][0],
strength=esrgan_parameters['upscale'][1], strength=esrgan_parameters["upscale"][1],
seed=seed seed=seed,
) )
progress['currentStatus'] = 'Saving image' progress["currentStatus"] = "Saving image"
socketio.emit('progressUpdate', progress) socketio.emit("progressUpdate", progress)
eventlet.sleep(0) eventlet.sleep(0)
esrgan_parameters['seed'] = seed esrgan_parameters["seed"] = seed
path = save_image(image, esrgan_parameters, result_path, postprocessing='esrgan') metadata = parameters_to_post_processed_image_metadata(
parameters=esrgan_parameters,
original_image_path=original_image["url"],
type="esrgan",
)
command = parameters_to_command(esrgan_parameters) command = parameters_to_command(esrgan_parameters)
path = save_image(image, command, metadata, result_path, postprocessing="esrgan")
write_log_message(f'[Upscaled] "{original_image["url"]}" > "{path}": {command}') write_log_message(f'[Upscaled] "{original_image["url"]}" > "{path}": {command}')
progress['currentStatus'] = 'Finished' progress["currentStatus"] = "Finished"
progress['currentStep'] = 0 progress["currentStep"] = 0
progress['totalSteps'] = 0 progress["totalSteps"] = 0
progress['currentIteration'] = 0 progress["currentIteration"] = 0
progress['totalIterations'] = 0 progress["totalIterations"] = 0
progress['isProcessing'] = False progress["isProcessing"] = False
socketio.emit('progressUpdate', progress) socketio.emit("progressUpdate", progress)
eventlet.sleep(0) eventlet.sleep(0)
socketio.emit( socketio.emit(
'esrganResult', {'url': os.path.relpath(path), 'uuid': original_image['uuid'], 'metadata': esrgan_parameters}) "esrganResult",
{
"url": os.path.relpath(path),
"metadata": metadata,
},
)
@socketio.on("runGFPGAN")
@socketio.on('runGFPGAN')
def handle_run_gfpgan_event(original_image, gfpgan_parameters): def handle_run_gfpgan_event(original_image, gfpgan_parameters):
print(f'>> GFPGAN face fix requested for "{original_image["url"]}": {gfpgan_parameters}') print(
f'>> GFPGAN face fix requested for "{original_image["url"]}": {gfpgan_parameters}'
)
progress = { progress = {
'currentStep': 1, "currentStep": 1,
'totalSteps': 1, "totalSteps": 1,
'currentIteration': 1, "currentIteration": 1,
'totalIterations': 1, "totalIterations": 1,
'currentStatus': 'Preparing', "currentStatus": "Preparing",
'isProcessing': True, "isProcessing": True,
'currentStatusHasSteps': False "currentStatusHasSteps": False,
} }
socketio.emit('progressUpdate', progress) socketio.emit("progressUpdate", progress)
eventlet.sleep(0) eventlet.sleep(0)
image = Image.open(original_image["url"]) image = Image.open(original_image["url"])
seed = original_image['metadata']['seed'] if 'seed' in original_image['metadata'] else 'unknown_seed' seed = (
original_image["metadata"]["seed"]
if "seed" in original_image["metadata"]
else "unknown_seed"
)
progress['currentStatus'] = 'Fixing faces' progress["currentStatus"] = "Fixing faces"
socketio.emit('progressUpdate', progress) socketio.emit("progressUpdate", progress)
eventlet.sleep(0) eventlet.sleep(0)
image = run_gfpgan( image = run_gfpgan(
image=image, image=image,
strength=gfpgan_parameters['gfpgan_strength'], strength=gfpgan_parameters["gfpgan_strength"],
seed=seed, seed=seed,
upsampler_scale=1 upsampler_scale=1,
) )
progress['currentStatus'] = 'Saving image' progress["currentStatus"] = "Saving image"
socketio.emit('progressUpdate', progress) socketio.emit("progressUpdate", progress)
eventlet.sleep(0) eventlet.sleep(0)
gfpgan_parameters['seed'] = seed gfpgan_parameters["seed"] = seed
path = save_image(image, gfpgan_parameters, result_path, postprocessing='gfpgan') metadata = parameters_to_post_processed_image_metadata(
parameters=gfpgan_parameters,
original_image_path=original_image["url"],
type="gfpgan",
)
command = parameters_to_command(gfpgan_parameters) command = parameters_to_command(gfpgan_parameters)
path = save_image(image, command, metadata, result_path, postprocessing="gfpgan")
write_log_message(f'[Fixed faces] "{original_image["url"]}" > "{path}": {command}') write_log_message(f'[Fixed faces] "{original_image["url"]}" > "{path}": {command}')
progress['currentStatus'] = 'Finished' progress["currentStatus"] = "Finished"
progress['currentStep'] = 0 progress["currentStep"] = 0
progress['totalSteps'] = 0 progress["totalSteps"] = 0
progress['currentIteration'] = 0 progress["currentIteration"] = 0
progress['totalIterations'] = 0 progress["totalIterations"] = 0
progress['isProcessing'] = False progress["isProcessing"] = False
socketio.emit('progressUpdate', progress) socketio.emit("progressUpdate", progress)
eventlet.sleep(0) eventlet.sleep(0)
socketio.emit( socketio.emit(
'gfpganResult', {'url': os.path.relpath(path), 'uuid': original_image['uuid'], 'metadata': gfpgan_parameters}) "gfpganResult",
{
"url": os.path.relpath(path),
"metadata": metadata,
},
)
@socketio.on('cancel') @socketio.on("cancel")
def handle_cancel(): def handle_cancel():
print(f'>> Cancel processing requested') print(f">> Cancel processing requested")
canceled.set() canceled.set()
socketio.emit('processingCanceled') socketio.emit("processingCanceled")
# TODO: I think this needs a safety mechanism. # TODO: I think this needs a safety mechanism.
@socketio.on('deleteImage') @socketio.on("deleteImage")
def handle_delete_image(path, uuid): def handle_delete_image(path, uuid):
print(f'>> Delete requested "{path}"') print(f'>> Delete requested "{path}"')
send2trash(path) send2trash(path)
socketio.emit('imageDeleted', {'url': path, 'uuid': uuid}) socketio.emit("imageDeleted", {"url": path, "uuid": uuid})
# TODO: I think this needs a safety mechanism. # TODO: I think this needs a safety mechanism.
@socketio.on('uploadInitialImage') @socketio.on("uploadInitialImage")
def handle_upload_initial_image(bytes, name): def handle_upload_initial_image(bytes, name):
print(f'>> Init image upload requested "{name}"') print(f'>> Init image upload requested "{name}"')
uuid = uuid4().hex uuid = uuid4().hex
split = os.path.splitext(name) split = os.path.splitext(name)
name = f'{split[0]}.{uuid}{split[1]}' name = f"{split[0]}.{uuid}{split[1]}"
file_path = os.path.join(init_image_path, name) file_path = os.path.join(init_image_path, name)
os.makedirs(os.path.dirname(file_path), exist_ok=True) os.makedirs(os.path.dirname(file_path), exist_ok=True)
newFile = open(file_path, "wb") newFile = open(file_path, "wb")
newFile.write(bytes) newFile.write(bytes)
socketio.emit('initialImageUploaded', {'url': file_path, 'uuid': ''}) socketio.emit("initialImageUploaded", {"url": file_path, "uuid": ""})
# TODO: I think this needs a safety mechanism. # TODO: I think this needs a safety mechanism.
@socketio.on('uploadMaskImage') @socketio.on("uploadMaskImage")
def handle_upload_mask_image(bytes, name): def handle_upload_mask_image(bytes, name):
print(f'>> Mask image upload requested "{name}"') print(f'>> Mask image upload requested "{name}"')
uuid = uuid4().hex uuid = uuid4().hex
split = os.path.splitext(name) split = os.path.splitext(name)
name = f'{split[0]}.{uuid}{split[1]}' name = f"{split[0]}.{uuid}{split[1]}"
file_path = os.path.join(mask_image_path, name) file_path = os.path.join(mask_image_path, name)
os.makedirs(os.path.dirname(file_path), exist_ok=True) os.makedirs(os.path.dirname(file_path), exist_ok=True)
newFile = open(file_path, "wb") newFile = open(file_path, "wb")
newFile.write(bytes) newFile.write(bytes)
socketio.emit('maskImageUploaded', {'url': file_path, 'uuid': ''}) socketio.emit("maskImageUploaded", {"url": file_path, "uuid": ""})
""" """
@ -329,50 +364,175 @@ END SOCKET.IO LISTENERS
""" """
""" """
ADDITIONAL FUNCTIONS ADDITIONAL FUNCTIONS
""" """
def get_system_config():
return {
"model": "stable diffusion",
"model_id": model,
"model_hash": generate.model_hash,
"app_id": APP_ID,
"app_version": APP_VERSION,
}
def parameters_to_post_processed_image_metadata(parameters, original_image_path, type):
# top-level metadata minus `image` or `images`
metadata = get_system_config()
orig_hash = calculate_init_img_hash(original_image_path)
image = {"orig_path": original_image_path, "orig_hash": orig_hash}
if type == "esrgan":
image["type"] = "esrgan"
image["scale"] = parameters["upscale"][0]
image["strength"] = parameters["upscale"][1]
elif type == "gfpgan":
image["type"] = "gfpgan"
image["strength"] = parameters["gfpgan_strength"]
else:
raise TypeError(f"Invalid type: {type}")
metadata["image"] = image
return metadata
def parameters_to_generated_image_metadata(parameters):
# top-level metadata minus `image` or `images`
metadata = get_system_config()
# remove any image keys not mentioned in RFC #266
rfc266_img_fields = [
"type",
"postprocessing",
"sampler",
"prompt",
"seed",
"variations",
"steps",
"cfg_scale",
"step_number",
"width",
"height",
"extra",
"seamless",
]
rfc_dict = {}
for item in parameters.items():
key, value = item
if key in rfc266_img_fields:
rfc_dict[key] = value
postprocessing = []
# 'postprocessing' is either null or an
if "gfpgan_strength" in parameters:
postprocessing.append(
{"type": "gfpgan", "strength": float(parameters["gfpgan_strength"])}
)
if "upscale" in parameters:
postprocessing.append(
{
"type": "esrgan",
"scale": int(parameters["upscale"][0]),
"strength": float(parameters["upscale"][1]),
}
)
rfc_dict["postprocessing"] = postprocessing if len(postprocessing) > 0 else None
# semantic drift
rfc_dict["sampler"] = parameters["sampler_name"]
# display weighted subprompts (liable to change)
subprompts = split_weighted_subprompts(parameters["prompt"])
subprompts = [{"prompt": x[0], "weight": x[1]} for x in subprompts]
rfc_dict["prompt"] = subprompts
# 'variations' should always exist and be an array, empty or consisting of {'seed': seed, 'weight': weight} pairs
variations = []
if "with_variations" in parameters:
variations = [
{"seed": x[0], "weight": x[1]} for x in parameters["with_variations"]
]
rfc_dict["variations"] = variations
if "init_img" in parameters:
rfc_dict["type"] = "img2img"
rfc_dict["strength"] = parameters["strength"]
rfc_dict["fit"] = parameters["fit"] # TODO: Noncompliant
rfc_dict["orig_hash"] = calculate_init_img_hash(parameters["init_img"])
rfc_dict["init_image_path"] = parameters["init_img"] # TODO: Noncompliant
rfc_dict["sampler"] = "ddim" # TODO: FIX ME WHEN IMG2IMG SUPPORTS ALL SAMPLERS
if "init_mask" in parameters:
rfc_dict["mask_hash"] = calculate_init_img_hash(
parameters["init_mask"]
) # TODO: Noncompliant
rfc_dict["mask_image_path"] = parameters["init_mask"] # TODO: Noncompliant
else:
rfc_dict["type"] = "txt2img"
metadata["image"] = rfc_dict
return metadata
def make_unique_init_image_filename(name): def make_unique_init_image_filename(name):
uuid = uuid4().hex uuid = uuid4().hex
split = os.path.splitext(name) split = os.path.splitext(name)
name = f'{split[0]}.{uuid}{split[1]}' name = f"{split[0]}.{uuid}{split[1]}"
return name return name
def write_log_message(message, log_path=log_path): def write_log_message(message, log_path=log_path):
"""Logs the filename and parameters used to generate or process that image to log file""" """Logs the filename and parameters used to generate or process that image to log file"""
message = f'{message}\n' message = f"{message}\n"
with open(log_path, 'a', encoding='utf-8') as file: with open(log_path, "a", encoding="utf-8") as file:
file.writelines(message) file.writelines(message)
def save_image(image, parameters, output_dir, step_index=None, postprocessing=False): def save_image(
seed = parameters['seed'] if 'seed' in parameters else 'unknown_seed' image, command, metadata, output_dir, step_index=None, postprocessing=False
):
pngwriter = PngWriter(output_dir) pngwriter = PngWriter(output_dir)
prefix = pngwriter.unique_prefix() prefix = pngwriter.unique_prefix()
filename = f'{prefix}.{seed}' seed = "unknown_seed"
if "image" in metadata:
if "seed" in metadata["image"]:
seed = metadata["image"]["seed"]
filename = f"{prefix}.{seed}"
if step_index: if step_index:
filename += f'.{step_index}' filename += f".{step_index}"
if postprocessing: if postprocessing:
filename += f'.postprocessed' filename += f".postprocessed"
filename += '.png' filename += ".png"
command = parameters_to_command(parameters) path = pngwriter.save_image_and_prompt_to_png(
image=image, dream_prompt=command, metadata=metadata, name=filename
path = pngwriter.save_image_and_prompt_to_png(image, command, metadata=parameters, name=filename) )
return path return path
def calculate_real_steps(steps, strength, has_init_image): def calculate_real_steps(steps, strength, has_init_image):
return math.floor(strength * steps) if has_init_image else steps return math.floor(strength * steps) if has_init_image else steps
def generate_images(generation_parameters, esrgan_parameters, gfpgan_parameters): def generate_images(generation_parameters, esrgan_parameters, gfpgan_parameters):
canceled.clear() canceled.clear()
@ -385,40 +545,40 @@ def generate_images(generation_parameters, esrgan_parameters, gfpgan_parameters)
If the init/mask image doesn't exist in the init_image_path/mask_image_path, If the init/mask image doesn't exist in the init_image_path/mask_image_path,
make a unique filename for it and copy it there. make a unique filename for it and copy it there.
""" """
if ('init_img' in generation_parameters): if "init_img" in generation_parameters:
filename = os.path.basename(generation_parameters['init_img']) filename = os.path.basename(generation_parameters["init_img"])
if not os.path.exists(os.path.join(init_image_path, filename)): if not os.path.exists(os.path.join(init_image_path, filename)):
unique_filename = make_unique_init_image_filename(filename) unique_filename = make_unique_init_image_filename(filename)
new_path = os.path.join(init_image_path, unique_filename) new_path = os.path.join(init_image_path, unique_filename)
shutil.copy(generation_parameters['init_img'], new_path) shutil.copy(generation_parameters["init_img"], new_path)
generation_parameters['init_img'] = new_path generation_parameters["init_img"] = new_path
if ('init_mask' in generation_parameters): if "init_mask" in generation_parameters:
filename = os.path.basename(generation_parameters['init_mask']) filename = os.path.basename(generation_parameters["init_mask"])
if not os.path.exists(os.path.join(mask_image_path, filename)): if not os.path.exists(os.path.join(mask_image_path, filename)):
unique_filename = make_unique_init_image_filename(filename) unique_filename = make_unique_init_image_filename(filename)
new_path = os.path.join(init_image_path, unique_filename) new_path = os.path.join(init_image_path, unique_filename)
shutil.copy(generation_parameters['init_img'], new_path) shutil.copy(generation_parameters["init_img"], new_path)
generation_parameters['init_mask'] = new_path generation_parameters["init_mask"] = new_path
totalSteps = calculate_real_steps( totalSteps = calculate_real_steps(
steps=generation_parameters['steps'], steps=generation_parameters["steps"],
strength=generation_parameters['strength'] if 'strength' in generation_parameters else None, strength=generation_parameters["strength"]
has_init_image='init_img' in generation_parameters if "strength" in generation_parameters
) else None,
has_init_image="init_img" in generation_parameters,
)
progress = { progress = {
'currentStep': 1, "currentStep": 1,
'totalSteps': totalSteps, "totalSteps": totalSteps,
'currentIteration': 1, "currentIteration": 1,
'totalIterations': generation_parameters['iterations'], "totalIterations": generation_parameters["iterations"],
'currentStatus': 'Preparing', "currentStatus": "Preparing",
'isProcessing': True, "isProcessing": True,
'currentStatusHasSteps': False "currentStatusHasSteps": False,
} }
socketio.emit('progressUpdate', progress) socketio.emit("progressUpdate", progress)
eventlet.sleep(0) eventlet.sleep(0)
def image_progress(sample, step): def image_progress(sample, step):
@ -429,18 +589,26 @@ def generate_images(generation_parameters, esrgan_parameters, gfpgan_parameters)
nonlocal generation_parameters nonlocal generation_parameters
nonlocal progress nonlocal progress
progress['currentStep'] = step + 1 progress["currentStep"] = step + 1
progress['currentStatus'] = 'Generating' progress["currentStatus"] = "Generating"
progress['currentStatusHasSteps'] = True progress["currentStatusHasSteps"] = True
if generation_parameters["progress_images"] and step % 5 == 0 and step < generation_parameters['steps'] - 1: if (
image = model.sample_to_image(sample) generation_parameters["progress_images"]
path = save_image(image, generation_parameters, intermediate_path, step_index) and step % 5 == 0
and step < generation_parameters["steps"] - 1
):
image = generate.sample_to_image(sample)
path = save_image(
image, generation_parameters, intermediate_path, step_index
)
step_index += 1 step_index += 1
socketio.emit('intermediateResult', { socketio.emit(
'url': os.path.relpath(path), 'metadata': generation_parameters}) "intermediateResult",
socketio.emit('progressUpdate', progress) {"url": os.path.relpath(path), "metadata": generation_parameters},
)
socketio.emit("progressUpdate", progress)
eventlet.sleep(0) eventlet.sleep(0)
def image_done(image, seed): def image_done(image, seed):
@ -451,79 +619,88 @@ def generate_images(generation_parameters, esrgan_parameters, gfpgan_parameters)
step_index = 1 step_index = 1
progress['currentStatus'] = 'Generation complete' progress["currentStatus"] = "Generation complete"
socketio.emit('progressUpdate', progress) socketio.emit("progressUpdate", progress)
eventlet.sleep(0) eventlet.sleep(0)
all_parameters = generation_parameters all_parameters = generation_parameters
postprocessing = False postprocessing = False
if esrgan_parameters: if esrgan_parameters:
progress['currentStatus'] = 'Upscaling' progress["currentStatus"] = "Upscaling"
progress['currentStatusHasSteps'] = False progress["currentStatusHasSteps"] = False
socketio.emit('progressUpdate', progress) socketio.emit("progressUpdate", progress)
eventlet.sleep(0) eventlet.sleep(0)
image = real_esrgan_upscale( image = real_esrgan_upscale(
image=image, image=image,
strength=esrgan_parameters['strength'], strength=esrgan_parameters["strength"],
upsampler_scale=esrgan_parameters['level'], upsampler_scale=esrgan_parameters["level"],
seed=seed seed=seed,
) )
postprocessing = True postprocessing = True
all_parameters["upscale"] = [esrgan_parameters['level'], esrgan_parameters['strength']] all_parameters["upscale"] = [
esrgan_parameters["level"],
esrgan_parameters["strength"],
]
if gfpgan_parameters: if gfpgan_parameters:
progress['currentStatus'] = 'Fixing faces' progress["currentStatus"] = "Fixing faces"
progress['currentStatusHasSteps'] = False progress["currentStatusHasSteps"] = False
socketio.emit('progressUpdate', progress) socketio.emit("progressUpdate", progress)
eventlet.sleep(0) eventlet.sleep(0)
image = run_gfpgan( image = run_gfpgan(
image=image, image=image,
strength=gfpgan_parameters['strength'], strength=gfpgan_parameters["strength"],
seed=seed, seed=seed,
upsampler_scale=1, upsampler_scale=1,
) )
postprocessing = True postprocessing = True
all_parameters["gfpgan_strength"] = gfpgan_parameters['strength'] all_parameters["gfpgan_strength"] = gfpgan_parameters["strength"]
all_parameters['seed'] = seed all_parameters["seed"] = seed
progress['currentStatus'] = 'Saving image' progress["currentStatus"] = "Saving image"
socketio.emit('progressUpdate', progress) socketio.emit("progressUpdate", progress)
eventlet.sleep(0) eventlet.sleep(0)
path = save_image(image, all_parameters, result_path, postprocessing=postprocessing) metadata = parameters_to_generated_image_metadata(all_parameters)
command = parameters_to_command(all_parameters) command = parameters_to_command(all_parameters)
print(f'Image generated: "{path}"') path = save_image(
image, command, metadata, result_path, postprocessing=postprocessing
)
print(f'>> Image generated: "{path}"')
write_log_message(f'[Generated] "{path}": {command}') write_log_message(f'[Generated] "{path}": {command}')
if (progress['totalIterations'] > progress['currentIteration']): if progress["totalIterations"] > progress["currentIteration"]:
progress['currentStep'] = 1 progress["currentStep"] = 1
progress['currentIteration'] +=1 progress["currentIteration"] += 1
progress['currentStatus'] = 'Iteration finished' progress["currentStatus"] = "Iteration finished"
progress['currentStatusHasSteps'] = False progress["currentStatusHasSteps"] = False
else: else:
progress['currentStep'] = 0 progress["currentStep"] = 0
progress['totalSteps'] = 0 progress["totalSteps"] = 0
progress['currentIteration'] = 0 progress["currentIteration"] = 0
progress['totalIterations'] = 0 progress["totalIterations"] = 0
progress['currentStatus'] = 'Finished' progress["currentStatus"] = "Finished"
progress['isProcessing'] = False progress["isProcessing"] = False
socketio.emit('progressUpdate', progress) socketio.emit("progressUpdate", progress)
eventlet.sleep(0) eventlet.sleep(0)
socketio.emit( socketio.emit(
'generationResult', {'url': os.path.relpath(path), 'metadata': all_parameters}) "generationResult",
{"url": os.path.relpath(path), "metadata": metadata},
)
eventlet.sleep(0) eventlet.sleep(0)
try: try:
model.prompt2image( generate.prompt2image(
**generation_parameters, **generation_parameters,
step_callback=image_progress, step_callback=image_progress,
image_callback=image_done image_callback=image_done,
) )
except KeyboardInterrupt: except KeyboardInterrupt:
@ -531,7 +708,7 @@ def generate_images(generation_parameters, esrgan_parameters, gfpgan_parameters)
except CanceledException: except CanceledException:
pass pass
except Exception as e: except Exception as e:
socketio.emit('error', {'message': (str(e))}) socketio.emit("error", {"message": (str(e))})
print("\n") print("\n")
traceback.print_exc() traceback.print_exc()
print("\n") print("\n")
@ -542,6 +719,6 @@ END ADDITIONAL FUNCTIONS
""" """
if __name__ == '__main__': if __name__ == "__main__":
print(f'Starting server at http://{host}:{port}') print(f">> Starting server at http://{host}:{port}")
socketio.run(app, host=host, port=port) socketio.run(app, host=host, port=port)

694
frontend/dist/assets/index.727a397b.js vendored Normal file

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -4,7 +4,7 @@
<meta charset="UTF-8" /> <meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" /> <meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Stable Diffusion Dream Server</title> <title>Stable Diffusion Dream Server</title>
<script type="module" crossorigin src="/assets/index.de730902.js"></script> <script type="module" crossorigin src="/assets/index.727a397b.js"></script>
<link rel="stylesheet" href="/assets/index.447eb2a9.css"> <link rel="stylesheet" href="/assets/index.447eb2a9.css">
</head> </head>
<body> <body>

View File

@ -1,7 +1,7 @@
{ {
"name": "sdui", "name": "invoke-ai-ui",
"private": true, "private": true,
"version": "0.0.0", "version": "0.0.1",
"type": "module", "type": "module",
"scripts": { "scripts": {
"dev": "vite dev", "dev": "vite dev",
@ -10,6 +10,7 @@
"preview": "vite preview" "preview": "vite preview"
}, },
"dependencies": { "dependencies": {
"@chakra-ui/icons": "^2.0.10",
"@chakra-ui/react": "^2.3.1", "@chakra-ui/react": "^2.3.1",
"@emotion/react": "^11.10.4", "@emotion/react": "^11.10.4",
"@emotion/styled": "^11.10.4", "@emotion/styled": "^11.10.4",

View File

@ -2,15 +2,15 @@ import { Grid, GridItem } from '@chakra-ui/react';
import { useEffect, useState } from 'react'; import { useEffect, useState } from 'react';
import CurrentImageDisplay from '../features/gallery/CurrentImageDisplay'; import CurrentImageDisplay from '../features/gallery/CurrentImageDisplay';
import ImageGallery from '../features/gallery/ImageGallery'; import ImageGallery from '../features/gallery/ImageGallery';
import ProgressBar from '../features/header/ProgressBar'; import ProgressBar from '../features/system/ProgressBar';
import SiteHeader from '../features/header/SiteHeader'; import SiteHeader from '../features/system/SiteHeader';
import OptionsAccordion from '../features/sd/OptionsAccordion'; import OptionsAccordion from '../features/options/OptionsAccordion';
import ProcessButtons from '../features/sd/ProcessButtons'; import ProcessButtons from '../features/options/ProcessButtons';
import PromptInput from '../features/sd/PromptInput'; import PromptInput from '../features/options/PromptInput';
import LogViewer from '../features/system/LogViewer'; import LogViewer from '../features/system/LogViewer';
import Loading from '../Loading'; import Loading from '../Loading';
import { useAppDispatch } from './store'; import { useAppDispatch } from './store';
import { requestAllImages } from './socketio/actions'; import { requestAllImages, requestSystemConfig } from './socketio/actions';
const App = () => { const App = () => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
@ -19,6 +19,7 @@ const App = () => {
// Load images from the gallery once // Load images from the gallery once
useEffect(() => { useEffect(() => {
dispatch(requestAllImages()); dispatch(requestAllImages());
dispatch(requestSystemConfig());
setIsReady(true); setIsReady(true);
}, [dispatch]); }, [dispatch]);

170
frontend/src/app/invokeai.d.ts vendored Normal file
View File

@ -0,0 +1,170 @@
/**
* Types for images, the things they are made of, and the things
* they make up.
*
* Generated images are txt2img and img2img images. They may have
* had additional postprocessing done on them when they were first
* generated.
*
* Postprocessed images are images which were not generated here
* but only postprocessed by the app. They only get postprocessing
* metadata and have a different image type, e.g. 'esrgan' or
* 'gfpgan'.
*/
/**
* TODO:
* Once an image has been generated, if it is postprocessed again,
* additional postprocessing steps are added to its postprocessing
* array.
*
* TODO: Better documentation of types.
*/
export declare type PromptItem = {
prompt: string;
weight: number;
};
export declare type Prompt = Array<PromptItem>;
export declare type SeedWeightPair = {
seed: number;
weight: number;
};
export declare type SeedWeights = Array<SeedWeightPair>;
// All generated images contain these metadata.
export declare type CommonGeneratedImageMetadata = {
postprocessing: null | Array<ESRGANMetadata | GFPGANMetadata>;
sampler:
| 'ddim'
| 'k_dpm_2_a'
| 'k_dpm_2'
| 'k_euler_a'
| 'k_euler'
| 'k_heun'
| 'k_lms'
| 'plms';
prompt: Prompt;
seed: number;
variations: SeedWeights;
steps: number;
cfg_scale: number;
width: number;
height: number;
seamless: boolean;
extra: null | Record<string, never>; // Pending development of RFC #266
};
// txt2img and img2img images have some unique attributes.
export declare type Txt2ImgMetadata = GeneratedImageMetadata & {
type: 'txt2img';
};
export declare type Img2ImgMetadata = GeneratedImageMetadata & {
type: 'img2img';
orig_hash: string;
strength: number;
fit: boolean;
init_image_path: string;
mask_image_path?: string;
};
// Superset of generated image metadata types.
export declare type GeneratedImageMetadata = Txt2ImgMetadata | Img2ImgMetadata;
// All post processed images contain these metadata.
export declare type CommonPostProcessedImageMetadata = {
orig_path: string;
orig_hash: string;
};
// esrgan and gfpgan images have some unique attributes.
export declare type ESRGANMetadata = CommonPostProcessedImageMetadata & {
type: 'esrgan';
scale: 2 | 4;
strength: number;
};
export declare type GFPGANMetadata = CommonPostProcessedImageMetadata & {
type: 'gfpgan';
strength: number;
};
// Superset of all postprocessed image metadata types..
export declare type PostProcessedImageMetadata =
| ESRGANMetadata
| GFPGANMetadata;
// Metadata includes the system config and image metadata.
export declare type Metadata = SystemConfig & {
image: GeneratedImageMetadata | PostProcessedImageMetadata;
};
// An Image has a UUID, url (path?) and Metadata.
export declare type Image = {
uuid: string;
url: string;
metadata: Metadata;
};
// GalleryImages is an array of Image.
export declare type GalleryImages = {
images: Array<Image>;
};
/**
* Types related to the system status.
*/
// This represents the processing status of the backend.
export declare type SystemStatus = {
isProcessing: boolean;
currentStep: number;
totalSteps: number;
currentIteration: number;
totalIterations: number;
currentStatus: string;
currentStatusHasSteps: boolean;
};
export declare type SystemConfig = {
model: string;
model_id: string;
model_hash: string;
app_id: string;
app_version: string;
};
/**
* These types type data received from the server via socketio.
*/
export declare type SystemStatusResponse = SystemStatus;
export declare type SystemConfigResponse = SystemConfig;
export declare type ImageResultResponse = {
url: string;
metadata: Metadata;
};
export declare type ErrorResponse = {
message: string;
additionalData?: string;
};
export declare type GalleryImagesResponse = {
images: Array<{ url: string; metadata: Metadata }>;
};
export declare type ImageUrlAndUuidResponse = {
uuid: string;
url: string;
};
export declare type ImageUrlResponse = {
url: string;
};

View File

@ -1,5 +1,5 @@
import { createAction } from '@reduxjs/toolkit'; import { createAction } from '@reduxjs/toolkit';
import { SDImage } from '../../features/gallery/gallerySlice'; import * as InvokeAI from '../invokeai';
/** /**
* We can't use redux-toolkit's createSlice() to make these actions, * We can't use redux-toolkit's createSlice() to make these actions,
@ -9,9 +9,9 @@ import { SDImage } from '../../features/gallery/gallerySlice';
*/ */
export const generateImage = createAction<undefined>('socketio/generateImage'); export const generateImage = createAction<undefined>('socketio/generateImage');
export const runESRGAN = createAction<SDImage>('socketio/runESRGAN'); export const runESRGAN = createAction<InvokeAI.Image>('socketio/runESRGAN');
export const runGFPGAN = createAction<SDImage>('socketio/runGFPGAN'); export const runGFPGAN = createAction<InvokeAI.Image>('socketio/runGFPGAN');
export const deleteImage = createAction<SDImage>('socketio/deleteImage'); export const deleteImage = createAction<InvokeAI.Image>('socketio/deleteImage');
export const requestAllImages = createAction<undefined>( export const requestAllImages = createAction<undefined>(
'socketio/requestAllImages' 'socketio/requestAllImages'
); );
@ -22,3 +22,5 @@ export const uploadInitialImage = createAction<File>(
'socketio/uploadInitialImage' 'socketio/uploadInitialImage'
); );
export const uploadMaskImage = createAction<File>('socketio/uploadMaskImage'); export const uploadMaskImage = createAction<File>('socketio/uploadMaskImage');
export const requestSystemConfig = createAction<undefined>('socketio/requestSystemConfig');

View File

@ -2,11 +2,11 @@ import { AnyAction, Dispatch, MiddlewareAPI } from '@reduxjs/toolkit';
import dateFormat from 'dateformat'; import dateFormat from 'dateformat';
import { Socket } from 'socket.io-client'; import { Socket } from 'socket.io-client';
import { frontendToBackendParameters } from '../../common/util/parameterTranslation'; import { frontendToBackendParameters } from '../../common/util/parameterTranslation';
import { SDImage } from '../../features/gallery/gallerySlice';
import { import {
addLogEntry, addLogEntry,
setIsProcessing, setIsProcessing,
} from '../../features/system/systemSlice'; } from '../../features/system/systemSlice';
import * as InvokeAI from '../invokeai';
/** /**
* Returns an object containing all functions which use `socketio.emit()`. * Returns an object containing all functions which use `socketio.emit()`.
@ -24,7 +24,7 @@ const makeSocketIOEmitters = (
dispatch(setIsProcessing(true)); dispatch(setIsProcessing(true));
const { generationParameters, esrganParameters, gfpganParameters } = const { generationParameters, esrganParameters, gfpganParameters } =
frontendToBackendParameters(getState().sd, getState().system); frontendToBackendParameters(getState().options, getState().system);
socketio.emit( socketio.emit(
'generateImage', 'generateImage',
@ -44,9 +44,9 @@ const makeSocketIOEmitters = (
}) })
); );
}, },
emitRunESRGAN: (imageToProcess: SDImage) => { emitRunESRGAN: (imageToProcess: InvokeAI.Image) => {
dispatch(setIsProcessing(true)); dispatch(setIsProcessing(true));
const { upscalingLevel, upscalingStrength } = getState().sd; const { upscalingLevel, upscalingStrength } = getState().options;
const esrganParameters = { const esrganParameters = {
upscale: [upscalingLevel, upscalingStrength], upscale: [upscalingLevel, upscalingStrength],
}; };
@ -61,9 +61,9 @@ const makeSocketIOEmitters = (
}) })
); );
}, },
emitRunGFPGAN: (imageToProcess: SDImage) => { emitRunGFPGAN: (imageToProcess: InvokeAI.Image) => {
dispatch(setIsProcessing(true)); dispatch(setIsProcessing(true));
const { gfpganStrength } = getState().sd; const { gfpganStrength } = getState().options;
const gfpganParameters = { const gfpganParameters = {
gfpgan_strength: gfpganStrength, gfpgan_strength: gfpganStrength,
@ -79,7 +79,7 @@ const makeSocketIOEmitters = (
}) })
); );
}, },
emitDeleteImage: (imageToDelete: SDImage) => { emitDeleteImage: (imageToDelete: InvokeAI.Image) => {
const { url, uuid } = imageToDelete; const { url, uuid } = imageToDelete;
socketio.emit('deleteImage', url, uuid); socketio.emit('deleteImage', url, uuid);
}, },
@ -95,6 +95,9 @@ const makeSocketIOEmitters = (
emitUploadMaskImage: (file: File) => { emitUploadMaskImage: (file: File) => {
socketio.emit('uploadMaskImage', file, file.name); socketio.emit('uploadMaskImage', file, file.name);
}, },
emitRequestSystemConfig: () => {
socketio.emit('requestSystemConfig')
}
}; };
}; };

View File

@ -2,38 +2,29 @@ import { AnyAction, MiddlewareAPI, Dispatch } from '@reduxjs/toolkit';
import { v4 as uuidv4 } from 'uuid'; import { v4 as uuidv4 } from 'uuid';
import dateFormat from 'dateformat'; import dateFormat from 'dateformat';
import * as InvokeAI from '../invokeai';
import { import {
addLogEntry, addLogEntry,
setIsConnected, setIsConnected,
setIsProcessing, setIsProcessing,
SystemStatus,
setSystemStatus, setSystemStatus,
setCurrentStatus, setCurrentStatus,
setSystemConfig,
} from '../../features/system/systemSlice'; } from '../../features/system/systemSlice';
import type {
ServerGenerationResult,
ServerESRGANResult,
ServerGFPGANResult,
ServerIntermediateResult,
ServerError,
ServerGalleryImages,
ServerImageUrlAndUuid,
ServerImageUrl,
} from './types';
import { backendToFrontendParameters } from '../../common/util/parameterTranslation';
import { import {
addImage, addImage,
clearIntermediateImage, clearIntermediateImage,
removeImage, removeImage,
SDImage,
setGalleryImages, setGalleryImages,
setIntermediateImage, setIntermediateImage,
} from '../../features/gallery/gallerySlice'; } from '../../features/gallery/gallerySlice';
import { setInitialImagePath, setMaskPath } from '../../features/sd/sdSlice'; import {
setInitialImagePath,
setMaskPath,
} from '../../features/options/optionsSlice';
/** /**
* Returns an object containing listener callbacks for socketio events. * Returns an object containing listener callbacks for socketio events.
@ -79,18 +70,16 @@ const makeSocketIOListeners = (
/** /**
* Callback to run when we receive a 'generationResult' event. * Callback to run when we receive a 'generationResult' event.
*/ */
onGenerationResult: (data: ServerGenerationResult) => { onGenerationResult: (data: InvokeAI.ImageResultResponse) => {
try { try {
const { url, metadata } = data; const { url, metadata } = data;
const newUuid = uuidv4(); const newUuid = uuidv4();
const translatedMetadata = backendToFrontendParameters(metadata);
dispatch( dispatch(
addImage({ addImage({
uuid: newUuid, uuid: newUuid,
url, url,
metadata: translatedMetadata, metadata: metadata,
}) })
); );
dispatch( dispatch(
@ -107,7 +96,7 @@ const makeSocketIOListeners = (
/** /**
* Callback to run when we receive a 'intermediateResult' event. * Callback to run when we receive a 'intermediateResult' event.
*/ */
onIntermediateResult: (data: ServerIntermediateResult) => { onIntermediateResult: (data: InvokeAI.ImageResultResponse) => {
try { try {
const uuid = uuidv4(); const uuid = uuidv4();
const { url, metadata } = data; const { url, metadata } = data;
@ -132,31 +121,15 @@ const makeSocketIOListeners = (
/** /**
* Callback to run when we receive an 'esrganResult' event. * Callback to run when we receive an 'esrganResult' event.
*/ */
onESRGANResult: (data: ServerESRGANResult) => { onESRGANResult: (data: InvokeAI.ImageResultResponse) => {
try { try {
const { url, uuid, metadata } = data; const { url, metadata } = data;
const newUuid = uuidv4();
// This image was only ESRGAN'd, grab the original image's metadata
const originalImage = getState().gallery.images.find(
(i: SDImage) => i.uuid === uuid
);
// Retain the original metadata
const newMetadata = {
...originalImage.metadata,
};
// Update the ESRGAN-related fields
newMetadata.shouldRunESRGAN = true;
newMetadata.upscalingLevel = metadata.upscale[0];
newMetadata.upscalingStrength = metadata.upscale[1];
dispatch( dispatch(
addImage({ addImage({
uuid: newUuid, uuid: uuidv4(),
url, url,
metadata: newMetadata, metadata,
}) })
); );
@ -174,30 +147,15 @@ const makeSocketIOListeners = (
/** /**
* Callback to run when we receive a 'gfpganResult' event. * Callback to run when we receive a 'gfpganResult' event.
*/ */
onGFPGANResult: (data: ServerGFPGANResult) => { onGFPGANResult: (data: InvokeAI.ImageResultResponse) => {
try { try {
const { url, uuid, metadata } = data; const { url, metadata } = data;
const newUuid = uuidv4();
// This image was only GFPGAN'd, grab the original image's metadata
const originalImage = getState().gallery.images.find(
(i: SDImage) => i.uuid === uuid
);
// Retain the original metadata
const newMetadata = {
...originalImage.metadata,
};
// Update the GFPGAN-related fields
newMetadata.shouldRunGFPGAN = true;
newMetadata.gfpganStrength = metadata.gfpgan_strength;
dispatch( dispatch(
addImage({ addImage({
uuid: newUuid, uuid: uuidv4(),
url, url,
metadata: newMetadata, metadata,
}) })
); );
@ -215,7 +173,7 @@ const makeSocketIOListeners = (
* Callback to run when we receive a 'progressUpdate' event. * Callback to run when we receive a 'progressUpdate' event.
* TODO: Add additional progress phases * TODO: Add additional progress phases
*/ */
onProgressUpdate: (data: SystemStatus) => { onProgressUpdate: (data: InvokeAI.SystemStatus) => {
try { try {
dispatch(setIsProcessing(true)); dispatch(setIsProcessing(true));
dispatch(setSystemStatus(data)); dispatch(setSystemStatus(data));
@ -226,7 +184,7 @@ const makeSocketIOListeners = (
/** /**
* Callback to run when we receive a 'progressUpdate' event. * Callback to run when we receive a 'progressUpdate' event.
*/ */
onError: (data: ServerError) => { onError: (data: InvokeAI.ErrorResponse) => {
const { message, additionalData } = data; const { message, additionalData } = data;
if (additionalData) { if (additionalData) {
@ -250,13 +208,14 @@ const makeSocketIOListeners = (
/** /**
* Callback to run when we receive a 'galleryImages' event. * Callback to run when we receive a 'galleryImages' event.
*/ */
onGalleryImages: (data: ServerGalleryImages) => { onGalleryImages: (data: InvokeAI.GalleryImagesResponse) => {
const { images } = data; const { images } = data;
const preparedImages = images.map((image): SDImage => { const preparedImages = images.map((image): InvokeAI.Image => {
const { url, metadata } = image;
return { return {
uuid: uuidv4(), uuid: uuidv4(),
url: image.path, url,
metadata: backendToFrontendParameters(image.metadata), metadata,
}; };
}); });
dispatch(setGalleryImages(preparedImages)); dispatch(setGalleryImages(preparedImages));
@ -296,7 +255,7 @@ const makeSocketIOListeners = (
/** /**
* Callback to run when we receive a 'imageDeleted' event. * Callback to run when we receive a 'imageDeleted' event.
*/ */
onImageDeleted: (data: ServerImageUrlAndUuid) => { onImageDeleted: (data: InvokeAI.ImageUrlAndUuidResponse) => {
const { url, uuid } = data; const { url, uuid } = data;
dispatch(removeImage(uuid)); dispatch(removeImage(uuid));
dispatch( dispatch(
@ -309,7 +268,7 @@ const makeSocketIOListeners = (
/** /**
* Callback to run when we receive a 'initialImageUploaded' event. * Callback to run when we receive a 'initialImageUploaded' event.
*/ */
onInitialImageUploaded: (data: ServerImageUrl) => { onInitialImageUploaded: (data: InvokeAI.ImageUrlResponse) => {
const { url } = data; const { url } = data;
dispatch(setInitialImagePath(url)); dispatch(setInitialImagePath(url));
dispatch( dispatch(
@ -322,7 +281,7 @@ const makeSocketIOListeners = (
/** /**
* Callback to run when we receive a 'maskImageUploaded' event. * Callback to run when we receive a 'maskImageUploaded' event.
*/ */
onMaskImageUploaded: (data: ServerImageUrl) => { onMaskImageUploaded: (data: InvokeAI.ImageUrlResponse) => {
const { url } = data; const { url } = data;
dispatch(setMaskPath(url)); dispatch(setMaskPath(url));
dispatch( dispatch(
@ -332,6 +291,9 @@ const makeSocketIOListeners = (
}) })
); );
}, },
onSystemConfig: (data: InvokeAI.SystemConfig) => {
dispatch(setSystemConfig(data));
},
}; };
}; };

View File

@ -4,18 +4,23 @@ import { io } from 'socket.io-client';
import makeSocketIOListeners from './listeners'; import makeSocketIOListeners from './listeners';
import makeSocketIOEmitters from './emitters'; import makeSocketIOEmitters from './emitters';
import type { import * as InvokeAI from '../invokeai';
ServerGenerationResult,
ServerESRGANResult,
ServerGFPGANResult,
ServerIntermediateResult,
ServerError,
ServerGalleryImages,
ServerImageUrlAndUuid,
ServerImageUrl,
} from './types';
import { SystemStatus } from '../../features/system/systemSlice';
/**
* Creates a socketio middleware to handle communication with server.
*
* Special `socketio/actionName` actions are created in actions.ts and
* exported for use by the application, which treats them like any old
* action, using `dispatch` to dispatch them.
*
* These actions are intercepted here, where `socketio.emit()` calls are
* made on their behalf - see `emitters.ts`. The emitter functions
* are the outbound communication to the server.
*
* Listeners are also established here - see `listeners.ts`. The listener
* functions receive communication from the server and usually dispatch
* some new action to handle whatever data was sent from the server.
*/
export const socketioMiddleware = () => { export const socketioMiddleware = () => {
const { hostname, port } = new URL(window.location.href); const { hostname, port } = new URL(window.location.href);
@ -38,6 +43,7 @@ export const socketioMiddleware = () => {
onImageDeleted, onImageDeleted,
onInitialImageUploaded, onInitialImageUploaded,
onMaskImageUploaded, onMaskImageUploaded,
onSystemConfig,
} = makeSocketIOListeners(store); } = makeSocketIOListeners(store);
const { const {
@ -49,6 +55,7 @@ export const socketioMiddleware = () => {
emitCancelProcessing, emitCancelProcessing,
emitUploadInitialImage, emitUploadInitialImage,
emitUploadMaskImage, emitUploadMaskImage,
emitRequestSystemConfig,
} = makeSocketIOEmitters(store, socketio); } = makeSocketIOEmitters(store, socketio);
/** /**
@ -60,29 +67,29 @@ export const socketioMiddleware = () => {
socketio.on('disconnect', () => onDisconnect()); socketio.on('disconnect', () => onDisconnect());
socketio.on('error', (data: ServerError) => onError(data)); socketio.on('error', (data: InvokeAI.ErrorResponse) => onError(data));
socketio.on('generationResult', (data: ServerGenerationResult) => socketio.on('generationResult', (data: InvokeAI.ImageResultResponse) =>
onGenerationResult(data) onGenerationResult(data)
); );
socketio.on('esrganResult', (data: ServerESRGANResult) => socketio.on('esrganResult', (data: InvokeAI.ImageResultResponse) =>
onESRGANResult(data) onESRGANResult(data)
); );
socketio.on('gfpganResult', (data: ServerGFPGANResult) => socketio.on('gfpganResult', (data: InvokeAI.ImageResultResponse) =>
onGFPGANResult(data) onGFPGANResult(data)
); );
socketio.on('intermediateResult', (data: ServerIntermediateResult) => socketio.on('intermediateResult', (data: InvokeAI.ImageResultResponse) =>
onIntermediateResult(data) onIntermediateResult(data)
); );
socketio.on('progressUpdate', (data: SystemStatus) => socketio.on('progressUpdate', (data: InvokeAI.SystemStatus) =>
onProgressUpdate(data) onProgressUpdate(data)
); );
socketio.on('galleryImages', (data: ServerGalleryImages) => socketio.on('galleryImages', (data: InvokeAI.GalleryImagesResponse) =>
onGalleryImages(data) onGalleryImages(data)
); );
@ -90,18 +97,22 @@ export const socketioMiddleware = () => {
onProcessingCanceled(); onProcessingCanceled();
}); });
socketio.on('imageDeleted', (data: ServerImageUrlAndUuid) => { socketio.on('imageDeleted', (data: InvokeAI.ImageUrlAndUuidResponse) => {
onImageDeleted(data); onImageDeleted(data);
}); });
socketio.on('initialImageUploaded', (data: ServerImageUrl) => { socketio.on('initialImageUploaded', (data: InvokeAI.ImageUrlResponse) => {
onInitialImageUploaded(data); onInitialImageUploaded(data);
}); });
socketio.on('maskImageUploaded', (data: ServerImageUrl) => { socketio.on('maskImageUploaded', (data: InvokeAI.ImageUrlResponse) => {
onMaskImageUploaded(data); onMaskImageUploaded(data);
}); });
socketio.on('systemConfig', (data: InvokeAI.SystemConfig) => {
onSystemConfig(data);
});
areListenersSet = true; areListenersSet = true;
} }
@ -148,6 +159,11 @@ export const socketioMiddleware = () => {
emitUploadMaskImage(action.payload); emitUploadMaskImage(action.payload);
break; break;
} }
case 'socketio/requestSystemConfig': {
emitRequestSystemConfig();
break;
}
} }
next(action); next(action);

View File

@ -1,46 +0,0 @@
/**
* Interfaces used by the socketio middleware.
*/
export declare interface ServerGenerationResult {
url: string;
metadata: { [key: string]: any };
}
export declare interface ServerESRGANResult {
url: string;
uuid: string;
metadata: { [key: string]: any };
}
export declare interface ServerGFPGANResult {
url: string;
uuid: string;
metadata: { [key: string]: any };
}
export declare interface ServerIntermediateResult {
url: string;
metadata: { [key: string]: any };
}
export declare interface ServerError {
message: string;
additionalData?: string;
}
export declare interface ServerGalleryImages {
images: Array<{
path: string;
metadata: { [key: string]: any };
}>;
}
export declare interface ServerImageUrlAndUuid {
uuid: string;
url: string;
}
export declare interface ServerImageUrl {
url: string;
}

View File

@ -5,7 +5,7 @@ import type { TypedUseSelectorHook } from 'react-redux';
import { persistReducer } from 'redux-persist'; import { persistReducer } from 'redux-persist';
import storage from 'redux-persist/lib/storage'; // defaults to localStorage for web import storage from 'redux-persist/lib/storage'; // defaults to localStorage for web
import sdReducer from '../features/sd/sdSlice'; import optionsReducer from '../features/options/optionsSlice';
import galleryReducer from '../features/gallery/gallerySlice'; import galleryReducer from '../features/gallery/gallerySlice';
import systemReducer from '../features/system/systemSlice'; import systemReducer from '../features/system/systemSlice';
import { socketioMiddleware } from './socketio/middleware'; import { socketioMiddleware } from './socketio/middleware';
@ -53,7 +53,7 @@ const systemPersistConfig = {
}; };
const reducers = combineReducers({ const reducers = combineReducers({
sd: sdReducer, options: optionsReducer,
gallery: galleryReducer, gallery: galleryReducer,
system: persistReducer(systemPersistConfig, systemReducer), system: persistReducer(systemPersistConfig, systemReducer),
}); });

View File

@ -33,5 +33,20 @@ export const theme = extendTheme({
fontWeight: 'light', fontWeight: 'light',
}, },
}, },
Button: {
variants: {
imageHoverIconButton: (props: StyleFunctionProps) => ({
bg: props.colorMode === 'dark' ? 'blackAlpha.700' : 'whiteAlpha.800',
color:
props.colorMode === 'dark' ? 'whiteAlpha.700' : 'blackAlpha.700',
_hover: {
bg:
props.colorMode === 'dark' ? 'blackAlpha.800' : 'whiteAlpha.800',
color:
props.colorMode === 'dark' ? 'whiteAlpha.900' : 'blackAlpha.900',
},
}),
},
},
}, },
}); });

View File

@ -1,171 +0,0 @@
/**
* Defines common parameters required to generate an image.
* See #266 for the eventual maturation of this interface.
*/
interface CommonParameters {
/**
* The "txt2img" prompt. String. Minimum one character. No maximum.
*/
prompt: string;
/**
* The number of sampler steps. Integer. Minimum value 1. No maximum.
*/
steps: number;
/**
* Classifier-free guidance scale. Float. Minimum value 0. Maximum?
*/
cfgScale: number;
/**
* Height of output image in pixels. Integer. Minimum 64. Must be multiple of 64. No maximum.
*/
height: number;
/**
* Width of output image in pixels. Integer. Minimum 64. Must be multiple of 64. No maximum.
*/
width: number;
/**
* Name of the sampler to use. String. Restricted values.
*/
sampler:
| 'ddim'
| 'plms'
| 'k_lms'
| 'k_dpm_2'
| 'k_dpm_2_a'
| 'k_euler'
| 'k_euler_a'
| 'k_heun';
/**
* Seed used for randomness. Integer. 0 --> 4294967295, inclusive.
*/
seed: number;
/**
* Flag to enable seamless tiling image generation. Boolean.
*/
seamless: boolean;
}
/**
* Defines parameters needed to use the "img2img" generation method.
*/
interface ImageToImageParameters {
/**
* Folder path to the image used as the initial image. String.
*/
initialImagePath: string;
/**
* Flag to enable the use of a mask image during "img2img" generations.
* Requires valid ImageToImageParameters. Boolean.
*/
shouldUseMaskImage: boolean;
/**
* Folder path to the image used as a mask image. String.
*/
maskImagePath: string;
/**
* Strength of adherance to initial image. Float. 0 --> 1, exclusive.
*/
img2imgStrength: number;
/**
* Flag to enable the stretching of init image to desired output. Boolean.
*/
shouldFit: boolean;
}
/**
* Defines the parameters needed to generate variations.
*/
interface VariationParameters {
/**
* Variation amount. Float. 0 --> 1, exclusive.
* TODO: What does this really do?
*/
variationAmount: number;
/**
* List of seed-weight pairs formatted as "seed:weight,...".
* Seed is a valid seed. Weight is a float, 0 --> 1, exclusive.
* String, must be parseable into [[seed,weight],...] format.
*/
seedWeights: string;
}
/**
* Defines the parameters needed to use GFPGAN postprocessing.
*/
interface GFPGANParameters {
/**
* GFPGAN strength. Strength to apply face-fixing processing. Float. 0 --> 1, exclusive.
*/
gfpganStrength: number;
}
/**
* Defines the parameters needed to use ESRGAN postprocessing.
*/
interface ESRGANParameters {
/**
* ESRGAN strength. Strength to apply upscaling. Float. 0 --> 1, exclusive.
*/
esrganStrength: number;
/**
* ESRGAN upscaling scale. One of 2x | 4x. Represented as integer.
*/
esrganScale: 2 | 4;
}
/**
* Extends the generation and processing method parameters, adding flags to enable each.
*/
interface ProcessingParameters extends CommonParameters {
/**
* Flag to enable the generation of variations. Requires valid VariationParameters. Boolean.
*/
shouldGenerateVariations: boolean;
/**
* Variation parameters.
*/
variationParameters: VariationParameters;
/**
* Flag to enable the use of an initial image, i.e. to use "img2img" generation.
* Requires valid ImageToImageParameters. Boolean.
*/
shouldUseImageToImage: boolean;
/**
* ImageToImage parameters.
*/
imageToImageParameters: ImageToImageParameters;
/**
* Flag to enable GFPGAN postprocessing. Requires valid GFPGANParameters. Boolean.
*/
shouldRunGFPGAN: boolean;
/**
* GFPGAN parameters.
*/
gfpganParameters: GFPGANParameters;
/**
* Flag to enable ESRGAN postprocessing. Requires valid ESRGANParameters. Boolean.
*/
shouldRunESRGAN: boolean;
/**
* ESRGAN parameters.
*/
esrganParameters: GFPGANParameters;
}
/**
* Extends ProcessingParameters, adding items needed to request processing.
*/
interface ProcessingState extends ProcessingParameters {
/**
* Number of images to generate. Integer. Minimum 1.
*/
iterations: number;
/**
* Flag to enable the randomization of the seed on each generation. Boolean.
*/
shouldRandomizeSeed: boolean;
}
export {}

View File

@ -3,20 +3,20 @@ import { isEqual } from 'lodash';
import { useMemo } from 'react'; import { useMemo } from 'react';
import { useAppSelector } from '../../app/store'; import { useAppSelector } from '../../app/store';
import { RootState } from '../../app/store'; import { RootState } from '../../app/store';
import { SDState } from '../../features/sd/sdSlice'; import { OptionsState } from '../../features/options/optionsSlice';
import { SystemState } from '../../features/system/systemSlice'; import { SystemState } from '../../features/system/systemSlice';
import { validateSeedWeights } from '../util/seedWeightPairs'; import { validateSeedWeights } from '../util/seedWeightPairs';
const sdSelector = createSelector( const optionsSelector = createSelector(
(state: RootState) => state.sd, (state: RootState) => state.options,
(sd: SDState) => { (options: OptionsState) => {
return { return {
prompt: sd.prompt, prompt: options.prompt,
shouldGenerateVariations: sd.shouldGenerateVariations, shouldGenerateVariations: options.shouldGenerateVariations,
seedWeights: sd.seedWeights, seedWeights: options.seedWeights,
maskPath: sd.maskPath, maskPath: options.maskPath,
initialImagePath: sd.initialImagePath, initialImagePath: options.initialImagePath,
seed: sd.seed, seed: options.seed,
}; };
}, },
{ {
@ -53,7 +53,7 @@ const useCheckParameters = (): boolean => {
maskPath, maskPath,
initialImagePath, initialImagePath,
seed, seed,
} = useAppSelector(sdSelector); } = useAppSelector(optionsSelector);
const { isProcessing, isConnected } = useAppSelector(systemSelector); const { isProcessing, isConnected } = useAppSelector(systemSelector);

View File

@ -1,180 +1,182 @@
/* /*
These functions translate frontend state into parameters These functions translate frontend state into parameters
suitable for consumption by the backend, and vice-versa. suitable for consumption by the backend, and vice-versa.
*/ */
import { NUMPY_RAND_MAX, NUMPY_RAND_MIN } from "../../app/constants"; import { NUMPY_RAND_MAX, NUMPY_RAND_MIN } from '../../app/constants';
import { SDState } from "../../features/sd/sdSlice"; import { OptionsState } from '../../features/options/optionsSlice';
import { SystemState } from "../../features/system/systemSlice"; import { SystemState } from '../../features/system/systemSlice';
import randomInt from "./randomInt"; import {
import { seedWeightsToString, stringToSeedWeights } from "./seedWeightPairs"; seedWeightsToString,
stringToSeedWeightsArray,
} from './seedWeightPairs';
import randomInt from './randomInt';
export const frontendToBackendParameters = ( export const frontendToBackendParameters = (
sdState: SDState, optionsState: OptionsState,
systemState: SystemState systemState: SystemState
): { [key: string]: any } => { ): { [key: string]: any } => {
const { const {
prompt, prompt,
iterations, iterations,
steps, steps,
cfgScale, cfgScale,
height, height,
width, width,
sampler, sampler,
seed, seed,
seamless, seamless,
shouldUseInitImage, shouldUseInitImage,
img2imgStrength, img2imgStrength,
initialImagePath, initialImagePath,
maskPath, maskPath,
shouldFitToWidthHeight, shouldFitToWidthHeight,
shouldGenerateVariations, shouldGenerateVariations,
variationAmount, variationAmount,
seedWeights, seedWeights,
shouldRunESRGAN, shouldRunESRGAN,
upscalingLevel, upscalingLevel,
upscalingStrength, upscalingStrength,
shouldRunGFPGAN, shouldRunGFPGAN,
gfpganStrength, gfpganStrength,
shouldRandomizeSeed, shouldRandomizeSeed,
} = sdState; } = optionsState;
const { shouldDisplayInProgress } = systemState; const { shouldDisplayInProgress } = systemState;
const generationParameters: { [k: string]: any } = { const generationParameters: { [k: string]: any } = {
prompt, prompt,
iterations, iterations,
steps, steps,
cfg_scale: cfgScale, cfg_scale: cfgScale,
height, height,
width, width,
sampler_name: sampler, sampler_name: sampler,
seed, seed,
seamless, seamless,
progress_images: shouldDisplayInProgress, progress_images: shouldDisplayInProgress,
};
generationParameters.seed = shouldRandomizeSeed
? randomInt(NUMPY_RAND_MIN, NUMPY_RAND_MAX)
: seed;
if (shouldUseInitImage) {
generationParameters.init_img = initialImagePath;
generationParameters.strength = img2imgStrength;
generationParameters.fit = shouldFitToWidthHeight;
if (maskPath) {
generationParameters.init_mask = maskPath;
}
}
if (shouldGenerateVariations) {
generationParameters.variation_amount = variationAmount;
if (seedWeights) {
generationParameters.with_variations =
stringToSeedWeightsArray(seedWeights);
}
} else {
generationParameters.variation_amount = 0;
}
let esrganParameters: false | { [k: string]: any } = false;
let gfpganParameters: false | { [k: string]: any } = false;
if (shouldRunESRGAN) {
esrganParameters = {
level: upscalingLevel,
strength: upscalingStrength,
}; };
}
generationParameters.seed = shouldRandomizeSeed if (shouldRunGFPGAN) {
? randomInt(NUMPY_RAND_MIN, NUMPY_RAND_MAX) gfpganParameters = {
: seed; strength: gfpganStrength,
if (shouldUseInitImage) {
generationParameters.init_img = initialImagePath;
generationParameters.strength = img2imgStrength;
generationParameters.fit = shouldFitToWidthHeight;
if (maskPath) {
generationParameters.init_mask = maskPath;
}
}
if (shouldGenerateVariations) {
generationParameters.variation_amount = variationAmount;
if (seedWeights) {
generationParameters.with_variations =
stringToSeedWeights(seedWeights);
}
} else {
generationParameters.variation_amount = 0;
}
let esrganParameters: false | { [k: string]: any } = false;
let gfpganParameters: false | { [k: string]: any } = false;
if (shouldRunESRGAN) {
esrganParameters = {
level: upscalingLevel,
strength: upscalingStrength,
};
}
if (shouldRunGFPGAN) {
gfpganParameters = {
strength: gfpganStrength,
};
}
return {
generationParameters,
esrganParameters,
gfpganParameters,
}; };
}
return {
generationParameters,
esrganParameters,
gfpganParameters,
};
}; };
export const backendToFrontendParameters = (parameters: { export const backendToFrontendParameters = (parameters: {
[key: string]: any; [key: string]: any;
}) => { }) => {
const { const {
prompt, prompt,
iterations, iterations,
steps, steps,
cfg_scale, cfg_scale,
height, height,
width, width,
sampler_name, sampler_name,
seed, seed,
seamless, seamless,
progress_images, progress_images,
variation_amount, variation_amount,
with_variations, with_variations,
gfpgan_strength, gfpgan_strength,
upscale, upscale,
init_img, init_img,
init_mask, init_mask,
strength, strength,
} = parameters; } = parameters;
const sd: { [key: string]: any } = { const options: { [key: string]: any } = {
shouldDisplayInProgress: progress_images, shouldDisplayInProgress: progress_images,
// init // init
shouldGenerateVariations: false, shouldGenerateVariations: false,
shouldRunESRGAN: false, shouldRunESRGAN: false,
shouldRunGFPGAN: false, shouldRunGFPGAN: false,
initialImagePath: '', initialImagePath: '',
maskPath: '', maskPath: '',
}; };
if (variation_amount > 0) { if (variation_amount > 0) {
sd.shouldGenerateVariations = true; options.shouldGenerateVariations = true;
sd.variationAmount = variation_amount; options.variationAmount = variation_amount;
if (with_variations) { if (with_variations) {
sd.seedWeights = seedWeightsToString(with_variations); options.seedWeights = seedWeightsToString(with_variations);
}
} }
}
if (gfpgan_strength > 0) { if (gfpgan_strength > 0) {
sd.shouldRunGFPGAN = true; options.shouldRunGFPGAN = true;
sd.gfpganStrength = gfpgan_strength; options.gfpganStrength = gfpgan_strength;
}
if (upscale) {
options.shouldRunESRGAN = true;
options.upscalingLevel = upscale[0];
options.upscalingStrength = upscale[1];
}
if (init_img) {
options.shouldUseInitImage = true;
options.initialImagePath = init_img;
options.strength = strength;
if (init_mask) {
options.maskPath = init_mask;
} }
}
if (upscale) { // if we had a prompt, add all the metadata, but if we don't have a prompt,
sd.shouldRunESRGAN = true; // we must have only done ESRGAN or GFPGAN so do not add that metadata
sd.upscalingLevel = upscale[0]; if (prompt) {
sd.upscalingStrength = upscale[1]; options.prompt = prompt;
} options.iterations = iterations;
options.steps = steps;
options.cfgScale = cfg_scale;
options.height = height;
options.width = width;
options.sampler = sampler_name;
options.seed = seed;
options.seamless = seamless;
}
if (init_img) { return options;
sd.shouldUseInitImage = true
sd.initialImagePath = init_img;
sd.strength = strength;
if (init_mask) {
sd.maskPath = init_mask;
}
}
// if we had a prompt, add all the metadata, but if we don't have a prompt,
// we must have only done ESRGAN or GFPGAN so do not add that metadata
if (prompt) {
sd.prompt = prompt;
sd.iterations = iterations;
sd.steps = steps;
sd.cfgScale = cfg_scale;
sd.height = height;
sd.width = width;
sd.sampler = sampler_name;
sd.seed = seed;
sd.seamless = seamless;
}
return sd;
}; };

View File

@ -0,0 +1,16 @@
import * as InvokeAI from '../../app/invokeai';
const promptToString = (prompt: InvokeAI.Prompt): string => {
if (prompt.length === 1) {
return prompt[0].prompt;
}
return prompt
.map(
(promptItem: InvokeAI.PromptItem): string =>
`${promptItem.prompt}:${promptItem.weight}`
)
.join(' ');
};
export default promptToString;

View File

@ -1,56 +1,68 @@
export interface SeedWeightPair { import * as InvokeAI from '../../app/invokeai';
seed: number;
weight: number;
}
export type SeedWeights = Array<Array<number>>; export const stringToSeedWeights = (
string: string
): InvokeAI.SeedWeights | boolean => {
const stringPairs = string.split(',');
const arrPairs = stringPairs.map((p) => p.split(':'));
const pairs = arrPairs.map((p: Array<string>): InvokeAI.SeedWeightPair => {
return { seed: parseInt(p[0]), weight: parseFloat(p[1]) };
});
export const stringToSeedWeights = (string: string): SeedWeights | boolean => { if (!validateSeedWeights(pairs)) {
const stringPairs = string.split(','); return false;
const arrPairs = stringPairs.map((p) => p.split(':')); }
const pairs = arrPairs.map((p) => {
return [parseInt(p[0]), parseFloat(p[1])];
});
if (!validateSeedWeights(pairs)) { return pairs;
return false;
}
return pairs;
}; };
export const validateSeedWeights = ( export const validateSeedWeights = (
seedWeights: SeedWeights | string seedWeights: InvokeAI.SeedWeights | string
): boolean => { ): boolean => {
return typeof seedWeights === 'string' return typeof seedWeights === 'string'
? Boolean(stringToSeedWeights(seedWeights)) ? Boolean(stringToSeedWeights(seedWeights))
: Boolean( : Boolean(
seedWeights.length && seedWeights.length &&
!seedWeights.some((pair) => { !seedWeights.some((pair: InvokeAI.SeedWeightPair) => {
const [seed, weight] = pair; const { seed, weight } = pair;
const isSeedValid = !isNaN(parseInt(seed.toString(), 10)); const isSeedValid = !isNaN(parseInt(seed.toString(), 10));
const isWeightValid = const isWeightValid =
!isNaN(parseInt(weight.toString(), 10)) && !isNaN(parseInt(weight.toString(), 10)) &&
weight >= 0 && weight >= 0 &&
weight <= 1; weight <= 1;
return !(isSeedValid && isWeightValid); return !(isSeedValid && isWeightValid);
}) })
); );
}; };
export const seedWeightsToString = ( export const seedWeightsToString = (
seedWeights: SeedWeights seedWeights: InvokeAI.SeedWeights
): string | boolean => { ): string => {
if (!validateSeedWeights(seedWeights)) { return seedWeights.reduce((acc, pair, i, arr) => {
return false; const { seed, weight } = pair;
acc += `${seed}:${weight}`;
if (i !== arr.length - 1) {
acc += ',';
} }
return acc;
return seedWeights.reduce((acc, pair, i, arr) => { }, '');
const [seed, weight] = pair; };
acc += `${seed}:${weight}`;
if (i !== arr.length - 1) { export const seedWeightsToArray = (
acc += ','; seedWeights: InvokeAI.SeedWeights
} ): Array<Array<number>> => {
return acc; return seedWeights.map((pair: InvokeAI.SeedWeightPair) => [
}, ''); pair.seed,
pair.weight,
]);
};
export const stringToSeedWeightsArray = (
string: string
): Array<Array<number>> => {
const stringPairs = string.split(',');
const arrPairs = stringPairs.map((p) => p.split(':'));
return arrPairs.map(
(p: Array<string>): Array<number> => [parseInt(p[0]), parseFloat(p[1])]
);
}; };

View File

@ -1,12 +1,18 @@
import { Flex } from '@chakra-ui/react'; import { Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash';
import * as InvokeAI from '../../app/invokeai';
import { useAppDispatch, useAppSelector } from '../../app/store'; import { useAppDispatch, useAppSelector } from '../../app/store';
import { RootState } from '../../app/store'; import { RootState } from '../../app/store';
import { setAllParameters, setInitialImagePath, setSeed } from '../sd/sdSlice'; import {
setAllParameters,
setInitialImagePath,
setSeed,
} from '../options/optionsSlice';
import DeleteImageModal from './DeleteImageModal'; import DeleteImageModal from './DeleteImageModal';
import { createSelector } from '@reduxjs/toolkit';
import { SystemState } from '../system/systemSlice'; import { SystemState } from '../system/systemSlice';
import { isEqual } from 'lodash';
import { SDImage } from './gallerySlice';
import SDButton from '../../common/components/SDButton'; import SDButton from '../../common/components/SDButton';
import { runESRGAN, runGFPGAN } from '../../app/socketio/actions'; import { runESRGAN, runGFPGAN } from '../../app/socketio/actions';
@ -28,7 +34,7 @@ const systemSelector = createSelector(
); );
type CurrentImageButtonsProps = { type CurrentImageButtonsProps = {
image: SDImage; image: InvokeAI.Image;
shouldShowImageDetails: boolean; shouldShowImageDetails: boolean;
setShouldShowImageDetails: (b: boolean) => void; setShouldShowImageDetails: (b: boolean) => void;
}; };
@ -49,7 +55,7 @@ const CurrentImageButtons = ({
); );
const { upscalingLevel, gfpganStrength } = useAppSelector( const { upscalingLevel, gfpganStrength } = useAppSelector(
(state: RootState) => state.sd (state: RootState) => state.options
); );
const { isProcessing, isConnected, isGFPGANAvailable, isESRGANAvailable } = const { isProcessing, isConnected, isGFPGANAvailable, isESRGANAvailable } =
@ -63,8 +69,7 @@ const CurrentImageButtons = ({
// Non-null assertion: this button is disabled if there is no seed. // Non-null assertion: this button is disabled if there is no seed.
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
const handleClickUseSeed = () => dispatch(setSeed(image.metadata.seed!)); const handleClickUseSeed = () => dispatch(setSeed(image.metadata.image.seed));
const handleClickUpscale = () => dispatch(runESRGAN(image)); const handleClickUpscale = () => dispatch(runESRGAN(image));
const handleClickFixFaces = () => dispatch(runGFPGAN(image)); const handleClickFixFaces = () => dispatch(runGFPGAN(image));
@ -87,6 +92,7 @@ const CurrentImageButtons = ({
colorScheme={'gray'} colorScheme={'gray'}
flexGrow={1} flexGrow={1}
variant={'outline'} variant={'outline'}
isDisabled={!['txt2img', 'img2img'].includes(image.metadata.image.type)}
onClick={handleClickUseAllParameters} onClick={handleClickUseAllParameters}
/> />
@ -95,7 +101,7 @@ const CurrentImageButtons = ({
colorScheme={'gray'} colorScheme={'gray'}
flexGrow={1} flexGrow={1}
variant={'outline'} variant={'outline'}
isDisabled={!image.metadata.seed} isDisabled={!image.metadata.image.seed}
onClick={handleClickUseSeed} onClick={handleClickUseSeed}
/> />

View File

@ -17,6 +17,7 @@ import { createSelector } from '@reduxjs/toolkit';
import { import {
ChangeEvent, ChangeEvent,
cloneElement, cloneElement,
forwardRef,
ReactElement, ReactElement,
SyntheticEvent, SyntheticEvent,
useRef, useRef,
@ -25,7 +26,7 @@ import { useAppDispatch, useAppSelector } from '../../app/store';
import { deleteImage } from '../../app/socketio/actions'; import { deleteImage } from '../../app/socketio/actions';
import { RootState } from '../../app/store'; import { RootState } from '../../app/store';
import { setShouldConfirmOnDelete, SystemState } from '../system/systemSlice'; import { setShouldConfirmOnDelete, SystemState } from '../system/systemSlice';
import { SDImage } from './gallerySlice'; import * as InvokeAI from '../../app/invokeai';
interface DeleteImageModalProps { interface DeleteImageModalProps {
/** /**
@ -35,7 +36,7 @@ interface DeleteImageModalProps {
/** /**
* The image to delete. * The image to delete.
*/ */
image: SDImage; image: InvokeAI.Image;
} }
const systemSelector = createSelector( const systemSelector = createSelector(
@ -49,73 +50,76 @@ const systemSelector = createSelector(
* If it is false, the image is deleted immediately. * If it is false, the image is deleted immediately.
* The confirmation modal has a "Don't ask me again" switch to set the boolean. * The confirmation modal has a "Don't ask me again" switch to set the boolean.
*/ */
const DeleteImageModal = ({ image, children }: DeleteImageModalProps) => { const DeleteImageModal = forwardRef(
const { isOpen, onOpen, onClose } = useDisclosure(); ({ image, children }: DeleteImageModalProps, ref) => {
const dispatch = useAppDispatch(); const { isOpen, onOpen, onClose } = useDisclosure();
const shouldConfirmOnDelete = useAppSelector(systemSelector); const dispatch = useAppDispatch();
const cancelRef = useRef<HTMLButtonElement>(null); const shouldConfirmOnDelete = useAppSelector(systemSelector);
const cancelRef = useRef<HTMLButtonElement>(null);
const handleClickDelete = (e: SyntheticEvent) => { const handleClickDelete = (e: SyntheticEvent) => {
e.stopPropagation(); e.stopPropagation();
shouldConfirmOnDelete ? onOpen() : handleDelete(); shouldConfirmOnDelete ? onOpen() : handleDelete();
}; };
const handleDelete = () => { const handleDelete = () => {
dispatch(deleteImage(image)); dispatch(deleteImage(image));
onClose(); onClose();
}; };
const handleChangeShouldConfirmOnDelete = ( const handleChangeShouldConfirmOnDelete = (
e: ChangeEvent<HTMLInputElement> e: ChangeEvent<HTMLInputElement>
) => dispatch(setShouldConfirmOnDelete(!e.target.checked)); ) => dispatch(setShouldConfirmOnDelete(!e.target.checked));
return ( return (
<> <>
{cloneElement(children, { {cloneElement(children, {
// TODO: This feels wrong. // TODO: This feels wrong.
onClick: handleClickDelete, onClick: handleClickDelete,
})} ref: ref,
})}
<AlertDialog <AlertDialog
isOpen={isOpen} isOpen={isOpen}
leastDestructiveRef={cancelRef} leastDestructiveRef={cancelRef}
onClose={onClose} onClose={onClose}
> >
<AlertDialogOverlay> <AlertDialogOverlay>
<AlertDialogContent> <AlertDialogContent>
<AlertDialogHeader fontSize="lg" fontWeight="bold"> <AlertDialogHeader fontSize="lg" fontWeight="bold">
Delete image Delete image
</AlertDialogHeader> </AlertDialogHeader>
<AlertDialogBody> <AlertDialogBody>
<Flex direction={'column'} gap={5}> <Flex direction={'column'} gap={5}>
<Text> <Text>
Are you sure? You can't undo this action afterwards. Are you sure? You can't undo this action afterwards.
</Text> </Text>
<FormControl> <FormControl>
<Flex alignItems={'center'}> <Flex alignItems={'center'}>
<FormLabel mb={0}>Don't ask me again</FormLabel> <FormLabel mb={0}>Don't ask me again</FormLabel>
<Switch <Switch
checked={!shouldConfirmOnDelete} checked={!shouldConfirmOnDelete}
onChange={handleChangeShouldConfirmOnDelete} onChange={handleChangeShouldConfirmOnDelete}
/> />
</Flex> </Flex>
</FormControl> </FormControl>
</Flex> </Flex>
</AlertDialogBody> </AlertDialogBody>
<AlertDialogFooter> <AlertDialogFooter>
<Button ref={cancelRef} onClick={onClose}> <Button ref={cancelRef} onClick={onClose}>
Cancel Cancel
</Button> </Button>
<Button colorScheme="red" onClick={handleDelete} ml={3}> <Button colorScheme="red" onClick={handleDelete} ml={3}>
Delete Delete
</Button> </Button>
</AlertDialogFooter> </AlertDialogFooter>
</AlertDialogContent> </AlertDialogContent>
</AlertDialogOverlay> </AlertDialogOverlay>
</AlertDialog> </AlertDialog>
</> </>
); );
}; }
);
export default DeleteImageModal; export default DeleteImageModal;

View File

@ -4,17 +4,20 @@ import {
Icon, Icon,
IconButton, IconButton,
Image, Image,
Tooltip,
useColorModeValue, useColorModeValue,
} from '@chakra-ui/react'; } from '@chakra-ui/react';
import { useAppDispatch } from '../../app/store'; import { useAppDispatch } from '../../app/store';
import { SDImage, setCurrentImage } from './gallerySlice'; import { setCurrentImage } from './gallerySlice';
import { FaCheck, FaCopy, FaSeedling, FaTrash } from 'react-icons/fa'; import { FaCheck, FaSeedling, FaTrashAlt } from 'react-icons/fa';
import DeleteImageModal from './DeleteImageModal'; import DeleteImageModal from './DeleteImageModal';
import { memo, SyntheticEvent, useState } from 'react'; import { memo, SyntheticEvent, useState } from 'react';
import { setAllParameters, setSeed } from '../sd/sdSlice'; import { setAllParameters, setSeed } from '../options/optionsSlice';
import * as InvokeAI from '../../app/invokeai';
import { IoArrowUndoCircleOutline } from 'react-icons/io5';
interface HoverableImageProps { interface HoverableImageProps {
image: SDImage; image: InvokeAI.Image;
isSelected: boolean; isSelected: boolean;
} }
@ -52,7 +55,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
e.stopPropagation(); e.stopPropagation();
// Non-null assertion: this button is not rendered unless this exists // Non-null assertion: this button is not rendered unless this exists
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
dispatch(setSeed(image.metadata.seed!)); dispatch(setSeed(image.metadata.image.seed));
}; };
const handleClickImage = () => dispatch(setCurrentImage(image)); const handleClickImage = () => dispatch(setCurrentImage(image));
@ -94,32 +97,41 @@ const HoverableImage = memo((props: HoverableImageProps) => {
top={1} top={1}
right={1} right={1}
> >
<DeleteImageModal image={image}> <Tooltip label={'Delete image'}>
<IconButton <DeleteImageModal image={image}>
colorScheme="red" <IconButton
aria-label="Delete image" colorScheme="red"
icon={<FaTrash />} aria-label="Delete image"
size="xs" icon={<FaTrashAlt />}
fontSize={15} size="xs"
/> variant={'imageHoverIconButton'}
</DeleteImageModal> fontSize={14}
<IconButton />
aria-label="Use all parameters" </DeleteImageModal>
colorScheme={'blue'} </Tooltip>
icon={<FaCopy />} {['txt2img', 'img2img'].includes(image.metadata.image.type) && (
size="xs" <Tooltip label="Use all parameters">
fontSize={15} <IconButton
onClickCapture={handleClickSetAllParameters} aria-label="Use all parameters"
/> icon={<IoArrowUndoCircleOutline />}
{image.metadata.seed && ( size="xs"
<IconButton fontSize={18}
aria-label="Use seed" variant={'imageHoverIconButton'}
colorScheme={'blue'} onClickCapture={handleClickSetAllParameters}
icon={<FaSeedling />} />
size="xs" </Tooltip>
fontSize={16} )}
onClickCapture={handleClickSetSeed} {image.metadata.image.seed && (
/> <Tooltip label="Use seed">
<IconButton
aria-label="Use seed"
icon={<FaSeedling />}
size="xs"
fontSize={16}
variant={'imageHoverIconButton'}
onClickCapture={handleClickSetSeed}
/>
</Tooltip>
)} )}
</Flex> </Flex>
)} )}

View File

@ -1,22 +1,82 @@
import { import {
Box,
Center, Center,
Flex, Flex,
IconButton, IconButton,
Link, Link,
List,
ListItem,
Text, Text,
Tooltip,
useColorModeValue,
} from '@chakra-ui/react'; } from '@chakra-ui/react';
import { ExternalLinkIcon } from '@chakra-ui/icons';
import { memo } from 'react'; import { memo } from 'react';
import { FaPlus } from 'react-icons/fa'; import { IoArrowUndoCircleOutline } from 'react-icons/io5';
import { PARAMETERS } from '../../app/constants';
import { useAppDispatch } from '../../app/store'; import { useAppDispatch } from '../../app/store';
import SDButton from '../../common/components/SDButton'; import * as InvokeAI from '../../app/invokeai';
import { setAllParameters, setParameter } from '../sd/sdSlice'; import {
import { SDImage, SDMetadata } from './gallerySlice'; setCfgScale,
setGfpganStrength,
setHeight,
setImg2imgStrength,
setInitialImagePath,
setMaskPath,
setPrompt,
setSampler,
setSeed,
setSeedWeights,
setShouldFitToWidthHeight,
setSteps,
setUpscalingLevel,
setUpscalingStrength,
setWidth,
} from '../options/optionsSlice';
import promptToString from '../../common/util/promptToString';
import { seedWeightsToString } from '../../common/util/seedWeightPairs';
import { FaCopy } from 'react-icons/fa';
type MetadataItemProps = {
isLink?: boolean;
label: string;
onClick?: () => void;
value: number | string | boolean;
};
/**
* Component to display an individual metadata item or parameter.
*/
const MetadataItem = ({ label, value, onClick, isLink }: MetadataItemProps) => {
return (
<Flex gap={2}>
{onClick && (
<Tooltip label={`Recall ${label}`}>
<IconButton
aria-label="Use this parameter"
icon={<IoArrowUndoCircleOutline />}
size={'xs'}
variant={'ghost'}
fontSize={20}
onClick={onClick}
/>
</Tooltip>
)}
<Text fontWeight={'semibold'} whiteSpace={'nowrap'}>
{label}:
</Text>
{isLink ? (
<Link href={value.toString()} isExternal wordBreak={'break-all'}>
{value.toString()} <ExternalLinkIcon mx="2px" />
</Link>
) : (
<Text maxHeight={100} overflowY={'scroll'} wordBreak={'break-all'}>
{value.toString()}
</Text>
)}
</Flex>
);
};
type ImageMetadataViewerProps = { type ImageMetadataViewerProps = {
image: SDImage; image: InvokeAI.Image;
}; };
// TODO: I don't know if this is needed. // TODO: I don't know if this is needed.
@ -33,91 +93,223 @@ const memoEqualityCheck = (
*/ */
const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => { const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const jsonBgColor = useColorModeValue('blackAlpha.100', 'whiteAlpha.100');
/** const metadata = image.metadata.image;
* Build an array representing each item of metadata and a human-readable const {
* label for it e.g. "cfgScale" > "CFG Scale". type,
* postprocessing,
* This array is then used to render each item with a button to use that sampler,
* parameter in the processing settings. prompt,
* seed,
* TODO: All this logic feels sloppy. variations,
*/ steps,
const keys = Object.keys(PARAMETERS); cfg_scale,
seamless,
width,
height,
strength,
fit,
init_image_path,
mask_image_path,
orig_path,
scale,
} = metadata;
const metadata: Array<{ const metadataJSON = JSON.stringify(metadata, null, 2);
label: string;
key: string;
value: string | number | boolean;
}> = [];
keys.forEach((key) => {
const value = image.metadata[key as keyof SDMetadata];
if (value !== undefined) {
metadata.push({ label: PARAMETERS[key], key, value });
}
});
return ( return (
<Flex gap={2} direction={'column'} overflowY={'scroll'} width={'100%'}> <Flex
<SDButton gap={1}
label="Use all parameters" direction={'column'}
colorScheme={'gray'} overflowY={'scroll'}
padding={2} width={'100%'}
isDisabled={metadata.length === 0} >
onClick={() => dispatch(setAllParameters(image.metadata))}
/>
<Flex gap={2}> <Flex gap={2}>
<Text fontWeight={'semibold'}>File:</Text> <Text fontWeight={'semibold'}>File:</Text>
<Link href={image.url} isExternal> <Link href={image.url} isExternal>
<Text>{image.url}</Text> {image.url}
<ExternalLinkIcon mx="2px" />
</Link> </Link>
</Flex> </Flex>
{metadata.length ? ( {Object.keys(metadata).length ? (
<> <>
<List> {type && <MetadataItem label="Type" value={type} />}
{metadata.map((parameter, i) => { {['esrgan', 'gfpgan'].includes(type) && (
const { label, key, value } = parameter; <MetadataItem label="Original image" value={orig_path} isLink />
return ( )}
<ListItem key={i} pb={1}> {type === 'gfpgan' && strength && (
<Flex gap={2}> <MetadataItem
<IconButton label="Fix faces strength"
aria-label="Use this parameter" value={strength}
icon={<FaPlus />} onClick={() => dispatch(setGfpganStrength(strength))}
size={'xs'} />
onClick={() => )}
dispatch( {type === 'esrgan' && scale && (
setParameter({ <MetadataItem
key, label="Upscaling scale"
value, value={scale}
}) onClick={() => dispatch(setUpscalingLevel(scale))}
) />
} )}
{type === 'esrgan' && strength && (
<MetadataItem
label="Upscaling strength"
value={strength}
onClick={() => dispatch(setUpscalingStrength(strength))}
/>
)}
{prompt && (
<MetadataItem
label="Prompt"
value={promptToString(prompt)}
onClick={() => dispatch(setPrompt(prompt))}
/>
)}
{seed && (
<MetadataItem
label="Seed"
value={seed}
onClick={() => dispatch(setSeed(seed))}
/>
)}
{sampler && (
<MetadataItem
label="Sampler"
value={sampler}
onClick={() => dispatch(setSampler(sampler))}
/>
)}
{steps && (
<MetadataItem
label="Steps"
value={steps}
onClick={() => dispatch(setSteps(steps))}
/>
)}
{cfg_scale && (
<MetadataItem
label="CFG scale"
value={cfg_scale}
onClick={() => dispatch(setCfgScale(cfg_scale))}
/>
)}
{variations && variations.length > 0 && (
<MetadataItem
label="Seed-weight pairs"
value={seedWeightsToString(variations)}
onClick={() =>
dispatch(setSeedWeights(seedWeightsToString(variations)))
}
/>
)}
{seamless && (
<MetadataItem
label="Seamless"
value={seamless}
onClick={() => dispatch(setWidth(seamless))}
/>
)}
{width && (
<MetadataItem
label="Width"
value={width}
onClick={() => dispatch(setWidth(width))}
/>
)}
{height && (
<MetadataItem
label="Height"
value={height}
onClick={() => dispatch(setHeight(height))}
/>
)}
{init_image_path && (
<MetadataItem
label="Initial image"
value={init_image_path}
isLink
onClick={() => dispatch(setInitialImagePath(init_image_path))}
/>
)}
{mask_image_path && (
<MetadataItem
label="Mask image"
value={mask_image_path}
isLink
onClick={() => dispatch(setMaskPath(mask_image_path))}
/>
)}
{type === 'img2img' && strength && (
<MetadataItem
label="Image to image strength"
value={strength}
onClick={() => dispatch(setImg2imgStrength(strength))}
/>
)}
{fit && (
<MetadataItem
label="Image to image fit"
value={fit}
onClick={() => dispatch(setShouldFitToWidthHeight(fit))}
/>
)}
{postprocessing &&
postprocessing.length > 0 &&
postprocessing.map(
(postprocess: InvokeAI.PostProcessedImageMetadata) => {
if (postprocess.type === 'esrgan') {
const { scale, strength } = postprocess;
return (
<>
<MetadataItem
label="Upscaling scale"
value={scale}
onClick={() => dispatch(setUpscalingLevel(scale))}
/>
<MetadataItem
label="Upscaling strength"
value={strength}
onClick={() => dispatch(setUpscalingStrength(strength))}
/>
</>
);
} else if (postprocess.type === 'gfpgan') {
const { strength } = postprocess;
return (
<MetadataItem
label="Fix faces strength"
value={strength}
onClick={() => dispatch(setGfpganStrength(strength))}
/> />
<Text fontWeight={'semibold'}>{label}:</Text> );
}
{value === undefined || }
value === null || )}
value === '' || <Flex gap={2} direction={'column'}>
value === 0 ? ( <Flex gap={2}>
<Text maxHeight={100} fontStyle={'italic'}> <Tooltip label={`Copy JSON`}>
None <IconButton
</Text> aria-label="Copy JSON"
) : ( icon={<FaCopy />}
<Text maxHeight={100} overflowY={'scroll'}> size={'xs'}
{value.toString()} variant={'ghost'}
</Text> fontSize={14}
)} onClick={() => navigator.clipboard.writeText(metadataJSON)}
</Flex> />
</ListItem> </Tooltip>
); <Text fontWeight={'semibold'}>JSON:</Text>
})} </Flex>
</List> <Box
<Flex gap={2}> // maxHeight={200}
<Text fontWeight={'semibold'}>Raw:</Text> overflow={'scroll'}
<Text maxHeight={100} overflowY={'scroll'} wordBreak={'break-all'}> flexGrow={3}
{JSON.stringify(image.metadata)} wordBreak={'break-all'}
</Text> bgColor={jsonBgColor}
padding={2}
>
<pre>{metadataJSON}</pre>
</Box>
</Flex> </Flex>
</> </>
) : ( ) : (

View File

@ -1,39 +1,13 @@
import { createSlice } from '@reduxjs/toolkit'; import { createSlice } from '@reduxjs/toolkit';
import type { PayloadAction } from '@reduxjs/toolkit'; import type { PayloadAction } from '@reduxjs/toolkit';
import { UpscalingLevel } from '../sd/sdSlice';
import { clamp } from 'lodash'; import { clamp } from 'lodash';
import * as InvokeAI from '../../app/invokeai';
// TODO: Revise pending metadata RFC: https://github.com/lstein/stable-diffusion/issues/266
export interface SDMetadata {
prompt?: string;
steps?: number;
cfgScale?: number;
height?: number;
width?: number;
sampler?: string;
seed?: number;
img2imgStrength?: number;
gfpganStrength?: number;
upscalingLevel?: UpscalingLevel;
upscalingStrength?: number;
initialImagePath?: string;
maskPath?: string;
seamless?: boolean;
shouldFitToWidthHeight?: boolean;
}
export interface SDImage {
// TODO: I have installed @types/uuid but cannot figure out how to use them here.
uuid: string;
url: string;
metadata: SDMetadata;
}
export interface GalleryState { export interface GalleryState {
currentImage?: InvokeAI.Image;
currentImageUuid: string; currentImageUuid: string;
images: Array<SDImage>; images: Array<InvokeAI.Image>;
intermediateImage?: SDImage; intermediateImage?: InvokeAI.Image;
currentImage?: SDImage;
} }
const initialState: GalleryState = { const initialState: GalleryState = {
@ -45,7 +19,7 @@ export const gallerySlice = createSlice({
name: 'gallery', name: 'gallery',
initialState, initialState,
reducers: { reducers: {
setCurrentImage: (state, action: PayloadAction<SDImage>) => { setCurrentImage: (state, action: PayloadAction<InvokeAI.Image>) => {
state.currentImage = action.payload; state.currentImage = action.payload;
state.currentImageUuid = action.payload.uuid; state.currentImageUuid = action.payload.uuid;
}, },
@ -92,19 +66,19 @@ export const gallerySlice = createSlice({
state.images = newImages; state.images = newImages;
}, },
addImage: (state, action: PayloadAction<SDImage>) => { addImage: (state, action: PayloadAction<InvokeAI.Image>) => {
state.images.push(action.payload); state.images.push(action.payload);
state.currentImageUuid = action.payload.uuid; state.currentImageUuid = action.payload.uuid;
state.intermediateImage = undefined; state.intermediateImage = undefined;
state.currentImage = action.payload; state.currentImage = action.payload;
}, },
setIntermediateImage: (state, action: PayloadAction<SDImage>) => { setIntermediateImage: (state, action: PayloadAction<InvokeAI.Image>) => {
state.intermediateImage = action.payload; state.intermediateImage = action.payload;
}, },
clearIntermediateImage: (state) => { clearIntermediateImage: (state) => {
state.intermediateImage = undefined; state.intermediateImage = undefined;
}, },
setGalleryImages: (state, action: PayloadAction<Array<SDImage>>) => { setGalleryImages: (state, action: PayloadAction<Array<InvokeAI.Image>>) => {
const newImages = action.payload; const newImages = action.payload;
if (newImages.length) { if (newImages.length) {
const newCurrentImage = newImages[newImages.length - 1]; const newCurrentImage = newImages[newImages.length - 1];
@ -117,12 +91,12 @@ export const gallerySlice = createSlice({
}); });
export const { export const {
setCurrentImage,
removeImage,
addImage, addImage,
clearIntermediateImage,
removeImage,
setCurrentImage,
setGalleryImages, setGalleryImages,
setIntermediateImage, setIntermediateImage,
clearIntermediateImage,
} = gallerySlice.actions; } = gallerySlice.actions;
export default gallerySlice.reducer; export default gallerySlice.reducer;

View File

@ -7,8 +7,8 @@ import {
setUpscalingLevel, setUpscalingLevel,
setUpscalingStrength, setUpscalingStrength,
UpscalingLevel, UpscalingLevel,
SDState, OptionsState,
} from '../sd/sdSlice'; } from '../options/optionsSlice';
import { UPSCALING_LEVELS } from '../../app/constants'; import { UPSCALING_LEVELS } from '../../app/constants';
@ -19,12 +19,12 @@ import { ChangeEvent } from 'react';
import SDNumberInput from '../../common/components/SDNumberInput'; import SDNumberInput from '../../common/components/SDNumberInput';
import SDSelect from '../../common/components/SDSelect'; import SDSelect from '../../common/components/SDSelect';
const sdSelector = createSelector( const optionsSelector = createSelector(
(state: RootState) => state.sd, (state: RootState) => state.options,
(sd: SDState) => { (options: OptionsState) => {
return { return {
upscalingLevel: sd.upscalingLevel, upscalingLevel: options.upscalingLevel,
upscalingStrength: sd.upscalingStrength, upscalingStrength: options.upscalingStrength,
}; };
}, },
{ {
@ -53,7 +53,7 @@ const systemSelector = createSelector(
*/ */
const ESRGANOptions = () => { const ESRGANOptions = () => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { upscalingLevel, upscalingStrength } = useAppSelector(sdSelector); const { upscalingLevel, upscalingStrength } = useAppSelector(optionsSelector);
const { isESRGANAvailable } = useAppSelector(systemSelector); const { isESRGANAvailable } = useAppSelector(systemSelector);
const handleChangeLevel = (e: ChangeEvent<HTMLSelectElement>) => const handleChangeLevel = (e: ChangeEvent<HTMLSelectElement>) =>

View File

@ -3,7 +3,7 @@ import { Flex } from '@chakra-ui/react';
import { RootState } from '../../app/store'; import { RootState } from '../../app/store';
import { useAppDispatch, useAppSelector } from '../../app/store'; import { useAppDispatch, useAppSelector } from '../../app/store';
import { SDState, setGfpganStrength } from '../sd/sdSlice'; import { OptionsState, setGfpganStrength } from '../options/optionsSlice';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
@ -11,11 +11,11 @@ import { isEqual } from 'lodash';
import { SystemState } from '../system/systemSlice'; import { SystemState } from '../system/systemSlice';
import SDNumberInput from '../../common/components/SDNumberInput'; import SDNumberInput from '../../common/components/SDNumberInput';
const sdSelector = createSelector( const optionsSelector = createSelector(
(state: RootState) => state.sd, (state: RootState) => state.options,
(sd: SDState) => { (options: OptionsState) => {
return { return {
gfpganStrength: sd.gfpganStrength, gfpganStrength: options.gfpganStrength,
}; };
}, },
{ {
@ -44,7 +44,7 @@ const systemSelector = createSelector(
*/ */
const GFPGANOptions = () => { const GFPGANOptions = () => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { gfpganStrength } = useAppSelector(sdSelector); const { gfpganStrength } = useAppSelector(optionsSelector);
const { isGFPGANAvailable } = useAppSelector(systemSelector); const { isGFPGANAvailable } = useAppSelector(systemSelector);
const handleChangeStrength = (v: string | number) => const handleChangeStrength = (v: string | number) =>

View File

@ -7,17 +7,17 @@ import SDNumberInput from '../../common/components/SDNumberInput';
import SDSwitch from '../../common/components/SDSwitch'; import SDSwitch from '../../common/components/SDSwitch';
import InitAndMaskImage from './InitAndMaskImage'; import InitAndMaskImage from './InitAndMaskImage';
import { import {
SDState, OptionsState,
setImg2imgStrength, setImg2imgStrength,
setShouldFitToWidthHeight, setShouldFitToWidthHeight,
} from './sdSlice'; } from './optionsSlice';
const sdSelector = createSelector( const optionsSelector = createSelector(
(state: RootState) => state.sd, (state: RootState) => state.options,
(sd: SDState) => { (options: OptionsState) => {
return { return {
img2imgStrength: sd.img2imgStrength, img2imgStrength: options.img2imgStrength,
shouldFitToWidthHeight: sd.shouldFitToWidthHeight, shouldFitToWidthHeight: options.shouldFitToWidthHeight,
}; };
} }
); );
@ -28,7 +28,7 @@ const sdSelector = createSelector(
const ImageToImageOptions = () => { const ImageToImageOptions = () => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { img2imgStrength, shouldFitToWidthHeight } = const { img2imgStrength, shouldFitToWidthHeight } =
useAppSelector(sdSelector); useAppSelector(optionsSelector);
const handleChangeStrength = (v: string | number) => const handleChangeStrength = (v: string | number) =>
dispatch(setImg2imgStrength(Number(v))); dispatch(setImg2imgStrength(Number(v)));

View File

@ -1,3 +1,4 @@
import { Box } from '@chakra-ui/react';
import { cloneElement, ReactElement, SyntheticEvent, useCallback } from 'react'; import { cloneElement, ReactElement, SyntheticEvent, useCallback } from 'react';
import { FileRejection, useDropzone } from 'react-dropzone'; import { FileRejection, useDropzone } from 'react-dropzone';
@ -51,12 +52,12 @@ const ImageUploader = ({
}; };
return ( return (
<div {...getRootProps()}> <Box {...getRootProps()} flexGrow={3}>
<input {...getInputProps({ multiple: false })} /> <input {...getInputProps({ multiple: false })} />
{cloneElement(children, { {cloneElement(children, {
onClick: handleClickUploadIcon, onClick: handleClickUploadIcon,
})} })}
</div> </Box>
); );
}; };

View File

@ -2,18 +2,18 @@ import { Flex, Image } from '@chakra-ui/react';
import { useState } from 'react'; import { useState } from 'react';
import { useAppSelector } from '../../app/store'; import { useAppSelector } from '../../app/store';
import { RootState } from '../../app/store'; import { RootState } from '../../app/store';
import { SDState } from '../../features/sd/sdSlice'; import { OptionsState } from '../../features/options/optionsSlice';
import './InitAndMaskImage.css'; import './InitAndMaskImage.css';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash'; import { isEqual } from 'lodash';
import InitAndMaskUploadButtons from './InitAndMaskUploadButtons'; import InitAndMaskUploadButtons from './InitAndMaskUploadButtons';
const sdSelector = createSelector( const optionsSelector = createSelector(
(state: RootState) => state.sd, (state: RootState) => state.options,
(sd: SDState) => { (options: OptionsState) => {
return { return {
initialImagePath: sd.initialImagePath, initialImagePath: options.initialImagePath,
maskPath: sd.maskPath, maskPath: options.maskPath,
}; };
}, },
{ memoizeOptions: { resultEqualityCheck: isEqual } } { memoizeOptions: { resultEqualityCheck: isEqual } }
@ -23,7 +23,7 @@ const sdSelector = createSelector(
* Displays init and mask images and buttons to upload/delete them. * Displays init and mask images and buttons to upload/delete them.
*/ */
const InitAndMaskImage = () => { const InitAndMaskImage = () => {
const { initialImagePath, maskPath } = useAppSelector(sdSelector); const { initialImagePath, maskPath } = useAppSelector(optionsSelector);
const [shouldShowMask, setShouldShowMask] = useState<boolean>(false); const [shouldShowMask, setShouldShowMask] = useState<boolean>(false);
return ( return (

View File

@ -1,25 +1,28 @@
import { Button, Flex, IconButton, useToast } from '@chakra-ui/react'; import { Button, Flex, IconButton, useToast } from '@chakra-ui/react';
import { SyntheticEvent, useCallback } from 'react'; import { SyntheticEvent, useCallback } from 'react';
import { FaTrash } from 'react-icons/fa'; import { FaTrash, FaUpload } from 'react-icons/fa';
import { useAppDispatch, useAppSelector } from '../../app/store'; import { useAppDispatch, useAppSelector } from '../../app/store';
import { RootState } from '../../app/store'; import { RootState } from '../../app/store';
import { import {
SDState, OptionsState,
setInitialImagePath, setInitialImagePath,
setMaskPath, setMaskPath,
} from '../../features/sd/sdSlice'; } from '../../features/options/optionsSlice';
import { uploadInitialImage, uploadMaskImage } from '../../app/socketio/actions'; import {
uploadInitialImage,
uploadMaskImage,
} from '../../app/socketio/actions';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash'; import { isEqual } from 'lodash';
import ImageUploader from './ImageUploader'; import ImageUploader from './ImageUploader';
import { FileRejection } from 'react-dropzone'; import { FileRejection } from 'react-dropzone';
const sdSelector = createSelector( const optionsSelector = createSelector(
(state: RootState) => state.sd, (state: RootState) => state.options,
(sd: SDState) => { (options: OptionsState) => {
return { return {
initialImagePath: sd.initialImagePath, initialImagePath: options.initialImagePath,
maskPath: sd.maskPath, maskPath: options.maskPath,
}; };
}, },
{ memoizeOptions: { resultEqualityCheck: isEqual } } { memoizeOptions: { resultEqualityCheck: isEqual } }
@ -36,15 +39,20 @@ const InitAndMaskUploadButtons = ({
setShouldShowMask, setShouldShowMask,
}: InitAndMaskUploadButtonsProps) => { }: InitAndMaskUploadButtonsProps) => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { initialImagePath } = useAppSelector(sdSelector); const { initialImagePath, maskPath } = useAppSelector(optionsSelector);
// Use a toast to alert user when a file upload is rejected // Use a toast to alert user when a file upload is rejected
const toast = useToast(); const toast = useToast();
// Clear the init and mask images // Clear the init and mask images
const handleClickResetInitialImageAndMask = (e: SyntheticEvent) => { const handleClickResetInitialImage = (e: SyntheticEvent) => {
e.stopPropagation(); e.stopPropagation();
dispatch(setInitialImagePath('')); dispatch(setInitialImagePath(''));
};
// Clear the init and mask images
const handleClickResetMask = (e: SyntheticEvent) => {
e.stopPropagation();
dispatch(setMaskPath('')); dispatch(setMaskPath(''));
}; };
@ -96,11 +104,21 @@ const InitAndMaskUploadButtons = ({
fontWeight={'normal'} fontWeight={'normal'}
onMouseOver={handleMouseOverInitialImageUploadButton} onMouseOver={handleMouseOverInitialImageUploadButton}
onMouseOut={handleMouseOutInitialImageUploadButton} onMouseOut={handleMouseOutInitialImageUploadButton}
leftIcon={<FaUpload />}
width={'100%'}
> >
Upload Image Image
</Button> </Button>
</ImageUploader> </ImageUploader>
<IconButton
isDisabled={!initialImagePath}
size={'sm'}
aria-label={'Reset mask'}
onClick={handleClickResetInitialImage}
icon={<FaTrash />}
/>
<ImageUploader <ImageUploader
fileAcceptedCallback={maskImageFileAcceptedCallback} fileAcceptedCallback={maskImageFileAcceptedCallback}
fileRejectionCallback={fileRejectionCallback} fileRejectionCallback={fileRejectionCallback}
@ -112,16 +130,18 @@ const InitAndMaskUploadButtons = ({
fontWeight={'normal'} fontWeight={'normal'}
onMouseOver={handleMouseOverMaskUploadButton} onMouseOver={handleMouseOverMaskUploadButton}
onMouseOut={handleMouseOutMaskUploadButton} onMouseOut={handleMouseOutMaskUploadButton}
leftIcon={<FaUpload />}
width={'100%'}
> >
Upload Mask Mask
</Button> </Button>
</ImageUploader> </ImageUploader>
<IconButton <IconButton
isDisabled={!initialImagePath} isDisabled={!maskPath}
size={'sm'} size={'sm'}
aria-label={'Reset initial image and mask'} aria-label={'Reset mask'}
onClick={handleClickResetInitialImageAndMask} onClick={handleClickResetMask}
icon={<FaTrash />} icon={<FaTrash />}
/> />
</Flex> </Flex>

View File

@ -17,9 +17,9 @@ import { useAppDispatch, useAppSelector } from '../../app/store';
import { import {
setShouldRunGFPGAN, setShouldRunGFPGAN,
setShouldRunESRGAN, setShouldRunESRGAN,
SDState, OptionsState,
setShouldUseInitImage, setShouldUseInitImage,
} from '../sd/sdSlice'; } from '../options/optionsSlice';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash'; import { isEqual } from 'lodash';
import { setOpenAccordions, SystemState } from '../system/systemSlice'; import { setOpenAccordions, SystemState } from '../system/systemSlice';
@ -31,14 +31,14 @@ import OutputOptions from './OutputOptions';
import ImageToImageOptions from './ImageToImageOptions'; import ImageToImageOptions from './ImageToImageOptions';
import { ChangeEvent } from 'react'; import { ChangeEvent } from 'react';
const sdSelector = createSelector( const optionsSelector = createSelector(
(state: RootState) => state.sd, (state: RootState) => state.options,
(sd: SDState) => { (options: OptionsState) => {
return { return {
initialImagePath: sd.initialImagePath, initialImagePath: options.initialImagePath,
shouldUseInitImage: sd.shouldUseInitImage, shouldUseInitImage: options.shouldUseInitImage,
shouldRunESRGAN: sd.shouldRunESRGAN, shouldRunESRGAN: options.shouldRunESRGAN,
shouldRunGFPGAN: sd.shouldRunGFPGAN, shouldRunGFPGAN: options.shouldRunGFPGAN,
}; };
}, },
{ {
@ -73,7 +73,7 @@ const OptionsAccordion = () => {
shouldRunGFPGAN, shouldRunGFPGAN,
shouldUseInitImage, shouldUseInitImage,
initialImagePath, initialImagePath,
} = useAppSelector(sdSelector); } = useAppSelector(optionsSelector);
const { isGFPGANAvailable, isESRGANAvailable, openAccordions } = const { isGFPGANAvailable, isESRGANAvailable, openAccordions } =
useAppSelector(systemSelector); useAppSelector(systemSelector);

View File

@ -3,7 +3,7 @@ import { Flex } from '@chakra-ui/react';
import { RootState } from '../../app/store'; import { RootState } from '../../app/store';
import { useAppDispatch, useAppSelector } from '../../app/store'; import { useAppDispatch, useAppSelector } from '../../app/store';
import { setHeight, setWidth, setSeamless, SDState } from '../sd/sdSlice'; import { setHeight, setWidth, setSeamless, OptionsState } from '../options/optionsSlice';
import { HEIGHTS, WIDTHS } from '../../app/constants'; import { HEIGHTS, WIDTHS } from '../../app/constants';
@ -13,13 +13,13 @@ import { ChangeEvent } from 'react';
import SDSelect from '../../common/components/SDSelect'; import SDSelect from '../../common/components/SDSelect';
import SDSwitch from '../../common/components/SDSwitch'; import SDSwitch from '../../common/components/SDSwitch';
const sdSelector = createSelector( const optionsSelector = createSelector(
(state: RootState) => state.sd, (state: RootState) => state.options,
(sd: SDState) => { (options: OptionsState) => {
return { return {
height: sd.height, height: options.height,
width: sd.width, width: options.width,
seamless: sd.seamless, seamless: options.seamless,
}; };
}, },
{ {
@ -34,7 +34,7 @@ const sdSelector = createSelector(
*/ */
const OutputOptions = () => { const OutputOptions = () => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { height, width, seamless } = useAppSelector(sdSelector); const { height, width, seamless } = useAppSelector(optionsSelector);
const handleChangeWidth = (e: ChangeEvent<HTMLSelectElement>) => const handleChangeWidth = (e: ChangeEvent<HTMLSelectElement>) =>
dispatch(setWidth(Number(e.target.value))); dispatch(setWidth(Number(e.target.value)));

View File

@ -6,13 +6,13 @@ import {
import { useAppDispatch, useAppSelector } from '../../app/store'; import { useAppDispatch, useAppSelector } from '../../app/store';
import { generateImage } from '../../app/socketio/actions'; import { generateImage } from '../../app/socketio/actions';
import { RootState } from '../../app/store'; import { RootState } from '../../app/store';
import { setPrompt } from '../sd/sdSlice'; import { setPrompt } from '../options/optionsSlice';
/** /**
* Prompt input text area. * Prompt input text area.
*/ */
const PromptInput = () => { const PromptInput = () => {
const { prompt } = useAppSelector((state: RootState) => state.sd); const { prompt } = useAppSelector((state: RootState) => state.options);
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const handleChangePrompt = (e: ChangeEvent<HTMLTextAreaElement>) => const handleChangePrompt = (e: ChangeEvent<HTMLTextAreaElement>) =>

View File

@ -3,7 +3,7 @@ import { Flex } from '@chakra-ui/react';
import { RootState } from '../../app/store'; import { RootState } from '../../app/store';
import { useAppDispatch, useAppSelector } from '../../app/store'; import { useAppDispatch, useAppSelector } from '../../app/store';
import { setCfgScale, setSampler, setSteps, SDState } from '../sd/sdSlice'; import { setCfgScale, setSampler, setSteps, OptionsState } from '../options/optionsSlice';
import { SAMPLERS } from '../../app/constants'; import { SAMPLERS } from '../../app/constants';
@ -13,13 +13,13 @@ import { ChangeEvent } from 'react';
import SDNumberInput from '../../common/components/SDNumberInput'; import SDNumberInput from '../../common/components/SDNumberInput';
import SDSelect from '../../common/components/SDSelect'; import SDSelect from '../../common/components/SDSelect';
const sdSelector = createSelector( const optionsSelector = createSelector(
(state: RootState) => state.sd, (state: RootState) => state.options,
(sd: SDState) => { (options: OptionsState) => {
return { return {
steps: sd.steps, steps: options.steps,
cfgScale: sd.cfgScale, cfgScale: options.cfgScale,
sampler: sd.sampler, sampler: options.sampler,
}; };
}, },
{ {
@ -34,7 +34,7 @@ const sdSelector = createSelector(
*/ */
const SamplerOptions = () => { const SamplerOptions = () => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { steps, cfgScale, sampler } = useAppSelector(sdSelector); const { steps, cfgScale, sampler } = useAppSelector(optionsSelector);
const handleChangeSteps = (v: string | number) => const handleChangeSteps = (v: string | number) =>
dispatch(setSteps(Number(v))); dispatch(setSteps(Number(v)));

View File

@ -18,25 +18,25 @@ import SDSwitch from '../../common/components/SDSwitch';
import randomInt from '../../common/util/randomInt'; import randomInt from '../../common/util/randomInt';
import { validateSeedWeights } from '../../common/util/seedWeightPairs'; import { validateSeedWeights } from '../../common/util/seedWeightPairs';
import { import {
SDState, OptionsState,
setIterations, setIterations,
setSeed, setSeed,
setSeedWeights, setSeedWeights,
setShouldGenerateVariations, setShouldGenerateVariations,
setShouldRandomizeSeed, setShouldRandomizeSeed,
setVariationAmount, setVariationAmount,
} from './sdSlice'; } from './optionsSlice';
const sdSelector = createSelector( const optionsSelector = createSelector(
(state: RootState) => state.sd, (state: RootState) => state.options,
(sd: SDState) => { (options: OptionsState) => {
return { return {
variationAmount: sd.variationAmount, variationAmount: options.variationAmount,
seedWeights: sd.seedWeights, seedWeights: options.seedWeights,
shouldGenerateVariations: sd.shouldGenerateVariations, shouldGenerateVariations: options.shouldGenerateVariations,
shouldRandomizeSeed: sd.shouldRandomizeSeed, shouldRandomizeSeed: options.shouldRandomizeSeed,
seed: sd.seed, seed: options.seed,
iterations: sd.iterations, iterations: options.iterations,
}; };
}, },
{ {
@ -57,7 +57,7 @@ const SeedVariationOptions = () => {
shouldRandomizeSeed, shouldRandomizeSeed,
seed, seed,
iterations, iterations,
} = useAppSelector(sdSelector); } = useAppSelector(optionsSelector);
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();

View File

@ -1,10 +1,12 @@
import { createSlice } from '@reduxjs/toolkit'; import { createSlice } from '@reduxjs/toolkit';
import type { PayloadAction } from '@reduxjs/toolkit'; import type { PayloadAction } from '@reduxjs/toolkit';
import { SDMetadata } from '../gallery/gallerySlice'; import * as InvokeAI from '../../app/invokeai';
import promptToString from '../../common/util/promptToString';
import { seedWeightsToString } from '../../common/util/seedWeightPairs';
export type UpscalingLevel = 2 | 4; export type UpscalingLevel = 2 | 4;
export interface SDState { export interface OptionsState {
prompt: string; prompt: string;
iterations: number; iterations: number;
steps: number; steps: number;
@ -30,7 +32,7 @@ export interface SDState {
shouldRandomizeSeed: boolean; shouldRandomizeSeed: boolean;
} }
const initialSDState: SDState = { const initialOptionsState: OptionsState = {
prompt: '', prompt: '',
iterations: 1, iterations: 1,
steps: 50, steps: 50,
@ -56,14 +58,19 @@ const initialSDState: SDState = {
shouldRandomizeSeed: true, shouldRandomizeSeed: true,
}; };
const initialState: SDState = initialSDState; const initialState: OptionsState = initialOptionsState;
export const sdSlice = createSlice({ export const optionsSlice = createSlice({
name: 'sd', name: 'options',
initialState, initialState,
reducers: { reducers: {
setPrompt: (state, action: PayloadAction<string>) => { setPrompt: (state, action: PayloadAction<string | InvokeAI.Prompt>) => {
state.prompt = action.payload; const newPrompt = action.payload;
if (typeof newPrompt === 'string') {
state.prompt = newPrompt;
} else {
state.prompt = promptToString(newPrompt);
}
}, },
setIterations: (state, action: PayloadAction<number>) => { setIterations: (state, action: PayloadAction<number>) => {
state.iterations = action.payload; state.iterations = action.payload;
@ -143,65 +150,89 @@ export const sdSlice = createSlice({
setSeedWeights: (state, action: PayloadAction<string>) => { setSeedWeights: (state, action: PayloadAction<string>) => {
state.seedWeights = action.payload; state.seedWeights = action.payload;
}, },
setAllParameters: (state, action: PayloadAction<SDMetadata>) => { setAllParameters: (state, action: PayloadAction<InvokeAI.Metadata>) => {
// TODO: This probably needs to be refactored.
const { const {
prompt, type,
steps, postprocessing,
cfgScale,
height,
width,
sampler, sampler,
prompt,
seed, seed,
img2imgStrength, variations,
gfpganStrength, steps,
upscalingLevel, cfg_scale,
upscalingStrength,
initialImagePath,
maskPath,
seamless, seamless,
shouldFitToWidthHeight, width,
} = action.payload; height,
strength,
fit,
init_image_path,
mask_image_path,
} = action.payload.image;
// ?? = falsy values ('', 0, etc) are used if (type === 'img2img') {
// || = falsy values not used if (init_image_path) state.initialImagePath = init_image_path;
state.prompt = prompt ?? state.prompt; if (mask_image_path) state.maskPath = mask_image_path;
state.steps = steps || state.steps; if (strength) state.img2imgStrength = strength;
state.cfgScale = cfgScale || state.cfgScale; if (typeof fit === 'boolean') state.shouldFitToWidthHeight = fit;
state.width = width || state.width; state.shouldUseInitImage = true;
state.height = height || state.height; } else {
state.sampler = sampler || state.sampler; state.shouldUseInitImage = false;
state.seed = seed ?? state.seed; }
state.seamless = seamless ?? state.seamless;
state.shouldFitToWidthHeight = if (variations && variations.length > 0) {
shouldFitToWidthHeight ?? state.shouldFitToWidthHeight; state.seedWeights = seedWeightsToString(variations);
state.img2imgStrength = img2imgStrength ?? state.img2imgStrength; state.shouldGenerateVariations = true;
state.gfpganStrength = gfpganStrength ?? state.gfpganStrength; } else {
state.upscalingLevel = upscalingLevel ?? state.upscalingLevel; state.shouldGenerateVariations = false;
state.upscalingStrength = upscalingStrength ?? state.upscalingStrength; }
state.initialImagePath = initialImagePath ?? state.initialImagePath;
state.maskPath = maskPath ?? state.maskPath;
// If the image whose parameters we are using has a seed, disable randomizing the seed
if (seed) { if (seed) {
state.seed = seed;
state.shouldRandomizeSeed = false; state.shouldRandomizeSeed = false;
} }
// if we have a gfpgan strength, enable it let postprocessingNotDone = ['gfpgan', 'esrgan'];
state.shouldRunGFPGAN = gfpganStrength ? true : false; if (postprocessing && postprocessing.length > 0) {
postprocessing.forEach(
(postprocess: InvokeAI.PostProcessedImageMetadata) => {
if (postprocess.type === 'gfpgan') {
const { strength } = postprocess;
if (strength) state.gfpganStrength = strength;
state.shouldRunGFPGAN = true;
postprocessingNotDone = postprocessingNotDone.filter(
(p) => p !== 'gfpgan'
);
}
if (postprocess.type === 'esrgan') {
const { scale, strength } = postprocess;
if (scale) state.upscalingLevel = scale;
if (strength) state.upscalingStrength = strength;
state.shouldRunESRGAN = true;
postprocessingNotDone = postprocessingNotDone.filter(
(p) => p !== 'esrgan'
);
}
}
);
}
// if we have a esrgan strength, enable it postprocessingNotDone.forEach((p) => {
state.shouldRunESRGAN = upscalingLevel ? true : false; if (p === 'esrgan') state.shouldRunESRGAN = false;
if (p === 'gfpgan') state.shouldRunGFPGAN = false;
});
// if we want to recreate an image exactly, we disable variations if (prompt) state.prompt = promptToString(prompt);
state.shouldGenerateVariations = false; if (sampler) state.sampler = sampler;
if (steps) state.steps = steps;
state.shouldUseInitImage = initialImagePath ? true : false; if (cfg_scale) state.cfgScale = cfg_scale;
if (typeof seamless === 'boolean') state.seamless = seamless;
if (width) state.width = width;
if (height) state.height = height;
}, },
resetSDState: (state) => { resetOptionsState: (state) => {
return { return {
...state, ...state,
...initialSDState, ...initialOptionsState,
}; };
}, },
setShouldRunGFPGAN: (state, action: PayloadAction<boolean>) => { setShouldRunGFPGAN: (state, action: PayloadAction<boolean>) => {
@ -234,7 +265,7 @@ export const {
setInitialImagePath, setInitialImagePath,
setMaskPath, setMaskPath,
resetSeed, resetSeed,
resetSDState, resetOptionsState,
setShouldFitToWidthHeight, setShouldFitToWidthHeight,
setParameter, setParameter,
setShouldGenerateVariations, setShouldGenerateVariations,
@ -244,6 +275,6 @@ export const {
setShouldRunGFPGAN, setShouldRunGFPGAN,
setShouldRunESRGAN, setShouldRunESRGAN,
setShouldRandomizeSeed, setShouldRandomizeSeed,
} = sdSlice.actions; } = optionsSlice.actions;
export default sdSlice.reducer; export default optionsSlice.reducer;

View File

@ -62,7 +62,7 @@ const SiteHeader = () => {
return ( return (
<Flex minWidth="max-content" alignItems="center" gap="1" pl={2} pr={1}> <Flex minWidth="max-content" alignItems="center" gap="1" pl={2} pr={1}>
<Heading size={'lg'}>Stable Diffusion Dream Server</Heading> <Heading size={'lg'}>InvokeUI</Heading>
<Spacer /> <Spacer />

View File

@ -1,6 +1,7 @@
import { createSlice } from '@reduxjs/toolkit'; import { createSlice } from '@reduxjs/toolkit';
import type { PayloadAction } from '@reduxjs/toolkit'; import type { PayloadAction } from '@reduxjs/toolkit';
import { ExpandedIndex } from '@chakra-ui/react'; import { ExpandedIndex } from '@chakra-ui/react';
import * as InvokeAI from '../../app/invokeai'
export type LogLevel = 'info' | 'warning' | 'error'; export type LogLevel = 'info' | 'warning' | 'error';
@ -14,17 +15,7 @@ export interface Log {
[index: number]: LogEntry; [index: number]: LogEntry;
} }
export interface SystemStatus { export interface SystemState extends InvokeAI.SystemStatus, InvokeAI.SystemConfig {
isProcessing: boolean;
currentStep: number;
totalSteps: number;
currentIteration: number;
totalIterations: number;
currentStatus: string;
currentStatusHasSteps: boolean;
}
export interface SystemState extends SystemStatus {
shouldDisplayInProgress: boolean; shouldDisplayInProgress: boolean;
log: Array<LogEntry>; log: Array<LogEntry>;
shouldShowLogViewer: boolean; shouldShowLogViewer: boolean;
@ -59,6 +50,11 @@ const initialSystemState = {
totalIterations: 0, totalIterations: 0,
currentStatus: '', currentStatus: '',
currentStatusHasSteps: false, currentStatusHasSteps: false,
model: '',
model_id: '',
model_hash: '',
app_id: '',
app_version: '',
}; };
const initialState: SystemState = initialSystemState; const initialState: SystemState = initialSystemState;
@ -76,7 +72,7 @@ export const systemSlice = createSlice({
setCurrentStatus: (state, action: PayloadAction<string>) => { setCurrentStatus: (state, action: PayloadAction<string>) => {
state.currentStatus = action.payload; state.currentStatus = action.payload;
}, },
setSystemStatus: (state, action: PayloadAction<SystemStatus>) => { setSystemStatus: (state, action: PayloadAction<InvokeAI.SystemStatus>) => {
const currentStatus = const currentStatus =
!action.payload.isProcessing && state.isConnected !action.payload.isProcessing && state.isConnected
? 'Connected' ? 'Connected'
@ -118,6 +114,9 @@ export const systemSlice = createSlice({
setOpenAccordions: (state, action: PayloadAction<ExpandedIndex>) => { setOpenAccordions: (state, action: PayloadAction<ExpandedIndex>) => {
state.openAccordions = action.payload; state.openAccordions = action.payload;
}, },
setSystemConfig: (state, action: PayloadAction<InvokeAI.SystemConfig>) => {
return { ...state, ...action.payload };
},
}, },
}); });
@ -132,6 +131,7 @@ export const {
setOpenAccordions, setOpenAccordions,
setSystemStatus, setSystemStatus,
setCurrentStatus, setCurrentStatus,
setSystemConfig,
} = systemSlice.actions; } = systemSlice.actions;
export default systemSlice.reducer; export default systemSlice.reducer;

View File

@ -419,6 +419,13 @@
compute-scroll-into-view "1.0.14" compute-scroll-into-view "1.0.14"
copy-to-clipboard "3.3.1" copy-to-clipboard "3.3.1"
"@chakra-ui/icon@3.0.10":
version "3.0.10"
resolved "https://registry.yarnpkg.com/@chakra-ui/icon/-/icon-3.0.10.tgz#1a11b5edb42a8af7aa5b6dec2bf2c6c4df1869fc"
integrity sha512-utO569d9bptEraJrEhuImfNzQ8v+a8PsQh8kTsodCzg8B16R3t5TTuoqeJqS6Nq16Vq6w87QbX3/4A73CNK5fw==
dependencies:
"@chakra-ui/shared-utils" "2.0.1"
"@chakra-ui/icon@3.0.9": "@chakra-ui/icon@3.0.9":
version "3.0.9" version "3.0.9"
resolved "https://registry.yarnpkg.com/@chakra-ui/icon/-/icon-3.0.9.tgz#ba127d9eefd727f62e9bce07a23eca39ae506744" resolved "https://registry.yarnpkg.com/@chakra-ui/icon/-/icon-3.0.9.tgz#ba127d9eefd727f62e9bce07a23eca39ae506744"
@ -426,6 +433,13 @@
dependencies: dependencies:
"@chakra-ui/shared-utils" "2.0.1" "@chakra-ui/shared-utils" "2.0.1"
"@chakra-ui/icons@^2.0.10":
version "2.0.10"
resolved "https://registry.yarnpkg.com/@chakra-ui/icons/-/icons-2.0.10.tgz#61aeb44c913c10e7ff77addc798494e50d66c760"
integrity sha512-hxMspvysOay2NsJyadM611F/Y4vVzJU/YkXTxsyBjm6v/DbENhpVmPnUf+kwwyl7dINNb9iOF+kuGxnuIEO1Tw==
dependencies:
"@chakra-ui/icon" "3.0.10"
"@chakra-ui/image@2.0.10": "@chakra-ui/image@2.0.10":
version "2.0.10" version "2.0.10"
resolved "https://registry.yarnpkg.com/@chakra-ui/image/-/image-2.0.10.tgz#712c0e1c579d959225bd8316d8d8f66cbeb95bb8" resolved "https://registry.yarnpkg.com/@chakra-ui/image/-/image-2.0.10.tgz#712c0e1c579d959225bd8316d8d8f66cbeb95bb8"

View File

@ -174,31 +174,37 @@ class Args(object):
switches.append(f'-W {a["width"]}') switches.append(f'-W {a["width"]}')
switches.append(f'-H {a["height"]}') switches.append(f'-H {a["height"]}')
switches.append(f'-C {a["cfg_scale"]}') switches.append(f'-C {a["cfg_scale"]}')
switches.append(f'-A {a["sampler_name"]}')
if a['grid']: if a['grid']:
switches.append('--grid') switches.append('--grid')
if a['seamless']: if a['seamless']:
switches.append('--seamless') switches.append('--seamless')
# img2img generations have parameters relevant only to them and have special handling
if a['init_img'] and len(a['init_img'])>0: if a['init_img'] and len(a['init_img'])>0:
switches.append(f'-I {a["init_img"]}') switches.append(f'-I {a["init_img"]}')
if a['init_mask'] and len(a['init_mask'])>0: switches.append(f'-A ddim') # TODO: FIX ME WHEN IMG2IMG SUPPORTS ALL SAMPLERS
switches.append(f'-M {a["init_mask"]}') if a['fit']:
if a['init_color'] and len(a['init_color'])>0: switches.append(f'--fit')
switches.append(f'--init_color {a["init_color"]}') if a['init_mask'] and len(a['init_mask'])>0:
if a['fit']: switches.append(f'-M {a["init_mask"]}')
switches.append(f'--fit') if a['init_color'] and len(a['init_color'])>0:
if a['init_img'] and a['strength'] and a['strength']>0: switches.append(f'--init_color {a["init_color"]}')
switches.append(f'-f {a["strength"]}') if a['strength'] and a['strength']>0:
switches.append(f'-f {a["strength"]}')
else:
switches.append(f'-A {a["sampler_name"]}')
# gfpgan-specific parameters
if a['gfpgan_strength']: if a['gfpgan_strength']:
switches.append(f'-G {a["gfpgan_strength"]}') switches.append(f'-G {a["gfpgan_strength"]}')
# esrgan-specific parameters
if a['upscale']: if a['upscale']:
switches.append(f'-U {" ".join([str(u) for u in a["upscale"]])}') switches.append(f'-U {" ".join([str(u) for u in a["upscale"]])}')
if a['embiggen']: if a['embiggen']:
switches.append(f'--embiggen {" ".join([str(u) for u in a["embiggen"]])}') switches.append(f'--embiggen {" ".join([str(u) for u in a["embiggen"]])}')
if a['embiggen_tiles']: if a['embiggen_tiles']:
switches.append(f'--embiggen_tiles {" ".join([str(u) for u in a["embiggen_tiles"]])}') switches.append(f'--embiggen_tiles {" ".join([str(u) for u in a["embiggen_tiles"]])}')
if a['variation_amount'] > 0:
switches.append(f'-v {a["variation_amount"]}')
if a['with_variations']: if a['with_variations']:
formatted_variations = ','.join(f'{seed}:{weight}' for seed, weight in (a["with_variations"])) formatted_variations = ','.join(f'{seed}:{weight}' for seed, weight in (a["with_variations"]))
switches.append(f'-V {formatted_variations}') switches.append(f'-V {formatted_variations}')
@ -618,18 +624,24 @@ def metadata_dumps(opt,
postprocessing=postprocessing postprocessing=postprocessing
) )
# TODO: This is just a hack until postprocessing pipeline work completed # 'postprocessing' is either null or an array of postprocessing metadatal
image_dict['postprocessing'] = [] if postprocessing:
if image_dict['gfpgan_strength'] and image_dict['gfpgan_strength'] > 0: # TODO: This is just a hack until postprocessing pipeline work completed
image_dict['postprocessing'].append('GFPGAN (not RFC compliant)') image_dict['postprocessing'] = []
if image_dict['upscale'] and image_dict['upscale'][0] > 0:
image_dict['postprocessing'].append('ESRGAN (not RFC compliant)') if image_dict['gfpgan_strength'] and image_dict['gfpgan_strength'] > 0:
image_dict['postprocessing'].append('GFPGAN (not RFC compliant)')
if image_dict['upscale'] and image_dict['upscale'][0] > 0:
image_dict['postprocessing'].append('ESRGAN (not RFC compliant)')
else:
image_dict['postprocessing'] = None
# remove any image keys not mentioned in RFC #266 # remove any image keys not mentioned in RFC #266
rfc266_img_fields = ['type','postprocessing','sampler','prompt','seed','variations','steps', rfc266_img_fields = ['type','postprocessing','sampler','prompt','seed','variations','steps',
'cfg_scale','step_number','width','height','extra','strength'] 'cfg_scale','step_number','width','height','extra','strength']
rfc_dict ={} rfc_dict ={}
for item in image_dict.items(): for item in image_dict.items():
key,value = item key,value = item
if key in rfc266_img_fields: if key in rfc266_img_fields:
@ -637,25 +649,24 @@ def metadata_dumps(opt,
# semantic drift # semantic drift
rfc_dict['sampler'] = image_dict.get('sampler_name',None) rfc_dict['sampler'] = image_dict.get('sampler_name',None)
# display weighted subprompts (liable to change) # display weighted subprompts (liable to change)
if opt.prompt: if opt.prompt:
subprompts = split_weighted_subprompts(opt.prompt) subprompts = split_weighted_subprompts(opt.prompt)
subprompts = [{'prompt':x[0],'weight':x[1]} for x in subprompts] subprompts = [{'prompt':x[0],'weight':x[1]} for x in subprompts]
rfc_dict['prompt'] = subprompts rfc_dict['prompt'] = subprompts
# variations # 'variations' should always exist and be an array, empty or consisting of {'seed': seed, 'weight': weight} pairs
if opt.with_variations: rfc_dict['variations'] = [{'seed':x[0],'weight':x[1]} for x in opt.with_variations] if opt.with_variations else []
variations = [{'seed':x[0],'weight':x[1]} for x in opt.with_variations]
rfc_dict['variations'] = variations
if opt.init_img: if opt.init_img:
rfc_dict['type'] = 'img2img' rfc_dict['type'] = 'img2img'
rfc_dict['strength_steps'] = rfc_dict.pop('strength') rfc_dict['strength_steps'] = rfc_dict.pop('strength')
rfc_dict['orig_hash'] = calculate_init_img_hash(opt.init_img) rfc_dict['orig_hash'] = calculate_init_img_hash(opt.init_img)
rfc_dict['sampler'] = 'ddim' # FIX ME WHEN IMG2IMG SUPPORTS ALL SAMPLERS rfc_dict['sampler'] = 'ddim' # TODO: FIX ME WHEN IMG2IMG SUPPORTS ALL SAMPLERS
else: else:
rfc_dict['type'] = 'txt2img' rfc_dict['type'] = 'txt2img'
rfc_dict.pop('strength')
if len(seeds)==0 and opt.seed: if len(seeds)==0 and opt.seed:
seeds=[seed] seeds=[seed]