Initial user uploads implementation

This commit is contained in:
psychedelicious 2022-10-28 23:15:03 +11:00
parent 9d1594cbcc
commit 3a7b495167
24 changed files with 362 additions and 1151 deletions

View File

@ -174,9 +174,12 @@ class InvokeAIWebServer:
)
@socketio.on("requestLatestImages")
def handle_request_latest_images(latest_mtime):
def handle_request_latest_images(category, latest_mtime):
try:
paths = glob.glob(os.path.join(self.result_path, "*.png"))
base_path = (
self.result_path if category == "result" else self.init_image_path
)
paths = glob.glob(os.path.join(base_path, "*.png"))
image_paths = sorted(
paths, key=lambda x: os.path.getmtime(x), reverse=True
@ -201,14 +204,13 @@ class InvokeAIWebServer:
"metadata": metadata["sd-metadata"],
"width": width,
"height": height,
"category": category,
}
)
socketio.emit(
"galleryImages",
{
"images": image_array,
},
{"images": image_array, "category": category},
)
except Exception as e:
self.socketio.emit("error", {"message": (str(e))})
@ -218,11 +220,15 @@ class InvokeAIWebServer:
print("\n")
@socketio.on("requestImages")
def handle_request_images(earliest_mtime=None):
def handle_request_images(category, earliest_mtime=None):
try:
page_size = 50
paths = glob.glob(os.path.join(self.result_path, "*.png"))
base_path = (
self.result_path if category == "result" else self.init_image_path
)
paths = glob.glob(os.path.join(base_path, "*.png"))
image_paths = sorted(
paths, key=lambda x: os.path.getmtime(x), reverse=True
@ -253,6 +259,7 @@ class InvokeAIWebServer:
"metadata": metadata["sd-metadata"],
"width": width,
"height": height,
"category": category,
}
)
@ -261,6 +268,7 @@ class InvokeAIWebServer:
{
"images": image_array,
"areMoreImagesAvailable": areMoreImagesAvailable,
"category": category,
},
)
except Exception as e:
@ -416,14 +424,17 @@ class InvokeAIWebServer:
# TODO: I think this needs a safety mechanism.
@socketio.on("deleteImage")
def handle_delete_image(url, uuid):
def handle_delete_image(url, uuid, category):
try:
print(f'>> Delete requested "{url}"')
from send2trash import send2trash
path = self.get_image_path_from_url(url)
print(path)
send2trash(path)
socketio.emit("imageDeleted", {"url": url, "uuid": uuid})
socketio.emit(
"imageDeleted", {"url": url, "uuid": uuid, "category": category}
)
except Exception as e:
self.socketio.emit("error", {"message": (str(e))})
print("\n")
@ -439,11 +450,17 @@ class InvokeAIWebServer:
file_path = self.save_file_unique_uuid_name(
bytes=bytes, name=name, path=self.init_image_path
)
mtime = os.path.getmtime(file_path)
(width, height) = Image.open(file_path).size
print(file_path)
socketio.emit(
"initialImageUploaded",
{
"url": self.get_url_from_image_path(file_path),
"mtime": mtime,
"width": width,
"height": height,
"category": "user",
},
)
except Exception as e:

View File

@ -1,822 +0,0 @@
import mimetypes
import transformers
import json
import os
import traceback
import eventlet
import glob
import shlex
import math
import shutil
import sys
sys.path.append(".")
from argparse import ArgumentTypeError
from modules.create_cmd_parser import create_cmd_parser
parser = create_cmd_parser()
opt = parser.parse_args()
from flask_socketio import SocketIO
from flask import Flask, send_from_directory, url_for, jsonify
from pathlib import Path
from PIL import Image
from pytorch_lightning import logging
from threading import Event
from uuid import uuid4
from send2trash import send2trash
from ldm.generate import Generate
from ldm.invoke.restoration import Restoration
from ldm.invoke.pngwriter import PngWriter, retrieve_metadata
from ldm.invoke.args import APP_ID, APP_VERSION, calculate_init_img_hash
from ldm.invoke.prompt_parser import split_weighted_subprompts
from modules.parameters import parameters_to_command
"""
USER CONFIG
"""
if opt.cors and "*" in opt.cors:
raise ArgumentTypeError('"*" is not an allowed CORS origin')
output_dir = "outputs/" # Base output directory for images
host = opt.host # Web & socket.io host
port = opt.port # Web & socket.io port
verbose = opt.verbose # enables copious socket.io logging
precision = opt.precision
free_gpu_mem = opt.free_gpu_mem
embedding_path = opt.embedding_path
additional_allowed_origins = (
opt.cors if opt.cors else []
) # additional CORS allowed origins
model = "stable-diffusion-1.4"
"""
END USER CONFIG
"""
print("* Initializing, be patient...\n")
"""
SERVER SETUP
"""
# fix missing mimetypes on windows due to registry wonkiness
mimetypes.add_type("application/javascript", ".js")
mimetypes.add_type("text/css", ".css")
app = Flask(__name__, static_url_path="", static_folder="../frontend/dist/")
app.config["OUTPUTS_FOLDER"] = "../outputs"
@app.route("/outputs/<path:filename>")
def outputs(filename):
return send_from_directory(app.config["OUTPUTS_FOLDER"], filename)
@app.route("/", defaults={"path": ""})
def serve(path):
return send_from_directory(app.static_folder, "index.html")
logger = True if verbose else False
engineio_logger = True if verbose else False
# default 1,000,000, needs to be higher for socketio to accept larger images
max_http_buffer_size = 10000000
cors_allowed_origins = [f"http://{host}:{port}"] + additional_allowed_origins
socketio = SocketIO(
app,
logger=logger,
engineio_logger=engineio_logger,
max_http_buffer_size=max_http_buffer_size,
cors_allowed_origins=cors_allowed_origins,
ping_interval=(50, 50),
ping_timeout=60,
)
"""
END SERVER SETUP
"""
"""
APP SETUP
"""
class CanceledException(Exception):
pass
try:
gfpgan, codeformer, esrgan = None, None, None
from ldm.invoke.restoration.base import Restoration
restoration = Restoration()
gfpgan, codeformer = restoration.load_face_restore_models()
esrgan = restoration.load_esrgan()
# coreformer.process(self, image, strength, device, seed=None, fidelity=0.75)
except (ModuleNotFoundError, ImportError):
print(traceback.format_exc(), file=sys.stderr)
print(">> You may need to install the ESRGAN and/or GFPGAN modules")
canceled = Event()
# reduce logging outputs to error
transformers.logging.set_verbosity_error()
logging.getLogger("pytorch_lightning").setLevel(logging.ERROR)
# Initialize and load model
generate = Generate(
model,
precision=precision,
embedding_path=embedding_path,
)
generate.free_gpu_mem = free_gpu_mem
generate.load_model()
# location for "finished" images
result_path = os.path.join(output_dir, "img-samples/")
# temporary path for intermediates
intermediate_path = os.path.join(result_path, "intermediates/")
# path for user-uploaded init images and masks
init_image_path = os.path.join(result_path, "init-images/")
mask_image_path = os.path.join(result_path, "mask-images/")
# txt log
log_path = os.path.join(result_path, "invoke_log.txt")
# make all output paths
[
os.makedirs(path, exist_ok=True)
for path in [result_path, intermediate_path, init_image_path, mask_image_path]
]
"""
END APP SETUP
"""
"""
SOCKET.IO LISTENERS
"""
@socketio.on("requestSystemConfig")
def handle_request_capabilities():
print(f">> System config requested")
config = get_system_config()
socketio.emit("systemConfig", config)
@socketio.on("requestImages")
def handle_request_images(page=1, offset=0, last_mtime=None):
chunk_size = 50
if last_mtime:
print(f">> Latest images requested")
else:
print(
f">> Page {page} of images requested (page size {chunk_size} offset {offset})"
)
paths = glob.glob(os.path.join(result_path, "*.png"))
sorted_paths = sorted(paths, key=lambda x: os.path.getmtime(x), reverse=True)
if last_mtime:
image_paths = filter(lambda x: os.path.getmtime(x) > last_mtime, sorted_paths)
else:
image_paths = sorted_paths[
slice(chunk_size * (page - 1) + offset, chunk_size * page + offset)
]
page = page + 1
image_array = []
for path in image_paths:
metadata = retrieve_metadata(path)
image_array.append(
{
"url": path,
"mtime": os.path.getmtime(path),
"metadata": metadata["sd-metadata"],
}
)
socketio.emit(
"galleryImages",
{
"images": image_array,
"nextPage": page,
"offset": offset,
"onlyNewImages": True if last_mtime else False,
},
)
@socketio.on("generateImage")
def handle_generate_image_event(
generation_parameters, esrgan_parameters, gfpgan_parameters
):
print(
f">> Image generation requested: {generation_parameters}\nESRGAN parameters: {esrgan_parameters}\nGFPGAN parameters: {gfpgan_parameters}"
)
generate_images(generation_parameters, esrgan_parameters, gfpgan_parameters)
@socketio.on("runESRGAN")
def handle_run_esrgan_event(original_image, esrgan_parameters):
print(
f'>> ESRGAN upscale requested for "{original_image["url"]}": {esrgan_parameters}'
)
progress = {
"currentStep": 1,
"totalSteps": 1,
"currentIteration": 1,
"totalIterations": 1,
"currentStatus": "Preparing",
"isProcessing": True,
"currentStatusHasSteps": False,
}
socketio.emit("progressUpdate", progress)
eventlet.sleep(0)
image = Image.open(original_image["url"])
seed = (
original_image["metadata"]["seed"]
if "seed" in original_image["metadata"]
else "unknown_seed"
)
progress["currentStatus"] = "Upscaling"
socketio.emit("progressUpdate", progress)
eventlet.sleep(0)
image = esrgan.process(
image=image,
upsampler_scale=esrgan_parameters["upscale"][0],
strength=esrgan_parameters["upscale"][1],
seed=seed,
)
progress["currentStatus"] = "Saving image"
socketio.emit("progressUpdate", progress)
eventlet.sleep(0)
esrgan_parameters["seed"] = seed
metadata = parameters_to_post_processed_image_metadata(
parameters=esrgan_parameters,
original_image_path=original_image["url"],
type="esrgan",
)
command = parameters_to_command(esrgan_parameters)
path = save_image(image, command, metadata, result_path, postprocessing="esrgan")
write_log_message(f'[Upscaled] "{original_image["url"]}" > "{path}": {command}')
progress["currentStatus"] = "Finished"
progress["currentStep"] = 0
progress["totalSteps"] = 0
progress["currentIteration"] = 0
progress["totalIterations"] = 0
progress["isProcessing"] = False
socketio.emit("progressUpdate", progress)
eventlet.sleep(0)
socketio.emit(
"esrganResult",
{
"url": os.path.relpath(path),
"mtime": os.path.getmtime(path),
"metadata": metadata,
},
)
@socketio.on("runGFPGAN")
def handle_run_gfpgan_event(original_image, gfpgan_parameters):
print(
f'>> GFPGAN face fix requested for "{original_image["url"]}": {gfpgan_parameters}'
)
progress = {
"currentStep": 1,
"totalSteps": 1,
"currentIteration": 1,
"totalIterations": 1,
"currentStatus": "Preparing",
"isProcessing": True,
"currentStatusHasSteps": False,
}
socketio.emit("progressUpdate", progress)
eventlet.sleep(0)
image = Image.open(original_image["url"])
seed = (
original_image["metadata"]["seed"]
if "seed" in original_image["metadata"]
else "unknown_seed"
)
progress["currentStatus"] = "Fixing faces"
socketio.emit("progressUpdate", progress)
eventlet.sleep(0)
image = gfpgan.process(
image=image, strength=gfpgan_parameters["facetool_strength"], seed=seed
)
progress["currentStatus"] = "Saving image"
socketio.emit("progressUpdate", progress)
eventlet.sleep(0)
gfpgan_parameters["seed"] = seed
metadata = parameters_to_post_processed_image_metadata(
parameters=gfpgan_parameters,
original_image_path=original_image["url"],
type="gfpgan",
)
command = parameters_to_command(gfpgan_parameters)
path = save_image(image, command, metadata, result_path, postprocessing="gfpgan")
write_log_message(f'[Fixed faces] "{original_image["url"]}" > "{path}": {command}')
progress["currentStatus"] = "Finished"
progress["currentStep"] = 0
progress["totalSteps"] = 0
progress["currentIteration"] = 0
progress["totalIterations"] = 0
progress["isProcessing"] = False
socketio.emit("progressUpdate", progress)
eventlet.sleep(0)
socketio.emit(
"gfpganResult",
{
"url": os.path.relpath(path),
"mtime": os.path.mtime(path),
"metadata": metadata,
},
)
@socketio.on("cancel")
def handle_cancel():
print(f">> Cancel processing requested")
canceled.set()
socketio.emit("processingCanceled")
# TODO: I think this needs a safety mechanism.
@socketio.on("deleteImage")
def handle_delete_image(path, uuid):
print(f'>> Delete requested "{path}"')
send2trash(path)
socketio.emit("imageDeleted", {"url": path, "uuid": uuid})
# TODO: I think this needs a safety mechanism.
@socketio.on("uploadInitialImage")
def handle_upload_initial_image(bytes, name):
print(f'>> Init image upload requested "{name}"')
uuid = uuid4().hex
split = os.path.splitext(name)
name = f"{split[0]}.{uuid}{split[1]}"
file_path = os.path.join(init_image_path, name)
os.makedirs(os.path.dirname(file_path), exist_ok=True)
newFile = open(file_path, "wb")
newFile.write(bytes)
socketio.emit("initialImageUploaded", {"url": file_path, "uuid": ""})
# TODO: I think this needs a safety mechanism.
@socketio.on("uploadMaskImage")
def handle_upload_mask_image(bytes, name):
print(f'>> Mask image upload requested "{name}"')
uuid = uuid4().hex
split = os.path.splitext(name)
name = f"{split[0]}.{uuid}{split[1]}"
file_path = os.path.join(mask_image_path, name)
os.makedirs(os.path.dirname(file_path), exist_ok=True)
newFile = open(file_path, "wb")
newFile.write(bytes)
socketio.emit("maskImageUploaded", {"url": file_path, "uuid": ""})
"""
END SOCKET.IO LISTENERS
"""
"""
ADDITIONAL FUNCTIONS
"""
def get_system_config():
return {
"model": "stable diffusion",
"model_id": model,
"model_hash": generate.model_hash,
"app_id": APP_ID,
"app_version": APP_VERSION,
}
def parameters_to_post_processed_image_metadata(parameters, original_image_path, type):
# top-level metadata minus `image` or `images`
metadata = get_system_config()
orig_hash = calculate_init_img_hash(original_image_path)
image = {"orig_path": original_image_path, "orig_hash": orig_hash}
if type == "esrgan":
image["type"] = "esrgan"
image["scale"] = parameters["upscale"][0]
image["strength"] = parameters["upscale"][1]
elif type == "gfpgan":
image["type"] = "gfpgan"
image["strength"] = parameters["facetool_strength"]
else:
raise TypeError(f"Invalid type: {type}")
metadata["image"] = image
return metadata
def parameters_to_generated_image_metadata(parameters):
# top-level metadata minus `image` or `images`
metadata = get_system_config()
# remove any image keys not mentioned in RFC #266
rfc266_img_fields = [
"type",
"postprocessing",
"sampler",
"prompt",
"seed",
"variations",
"steps",
"cfg_scale",
"threshold",
"perlin",
"step_number",
"width",
"height",
"extra",
"seamless",
"hires_fix",
]
rfc_dict = {}
for item in parameters.items():
key, value = item
if key in rfc266_img_fields:
rfc_dict[key] = value
postprocessing = []
# 'postprocessing' is either null or an
if "facetool_strength" in parameters:
postprocessing.append(
{"type": "gfpgan", "strength": float(parameters["facetool_strength"])}
)
if "upscale" in parameters:
postprocessing.append(
{
"type": "esrgan",
"scale": int(parameters["upscale"][0]),
"strength": float(parameters["upscale"][1]),
}
)
rfc_dict["postprocessing"] = postprocessing if len(postprocessing) > 0 else None
# semantic drift
rfc_dict["sampler"] = parameters["sampler_name"]
# display weighted subprompts (liable to change)
subprompts = split_weighted_subprompts(parameters["prompt"])
subprompts = [{"prompt": x[0], "weight": x[1]} for x in subprompts]
rfc_dict["prompt"] = subprompts
# 'variations' should always exist and be an array, empty or consisting of {'seed': seed, 'weight': weight} pairs
variations = []
if "with_variations" in parameters:
variations = [
{"seed": x[0], "weight": x[1]} for x in parameters["with_variations"]
]
rfc_dict["variations"] = variations
if "init_img" in parameters:
rfc_dict["type"] = "img2img"
rfc_dict["strength"] = parameters["strength"]
rfc_dict["fit"] = parameters["fit"] # TODO: Noncompliant
rfc_dict["orig_hash"] = calculate_init_img_hash(parameters["init_img"])
rfc_dict["init_image_path"] = parameters["init_img"] # TODO: Noncompliant
rfc_dict["sampler"] = "ddim" # TODO: FIX ME WHEN IMG2IMG SUPPORTS ALL SAMPLERS
if "init_mask" in parameters:
rfc_dict["mask_hash"] = calculate_init_img_hash(
parameters["init_mask"]
) # TODO: Noncompliant
rfc_dict["mask_image_path"] = parameters["init_mask"] # TODO: Noncompliant
else:
rfc_dict["type"] = "txt2img"
metadata["image"] = rfc_dict
return metadata
def make_unique_init_image_filename(name):
uuid = uuid4().hex
split = os.path.splitext(name)
name = f"{split[0]}.{uuid}{split[1]}"
return name
def write_log_message(message, log_path=log_path):
"""Logs the filename and parameters used to generate or process that image to log file"""
message = f"{message}\n"
with open(log_path, "a", encoding="utf-8") as file:
file.writelines(message)
def save_image(
image, command, metadata, output_dir, step_index=None, postprocessing=False
):
pngwriter = PngWriter(output_dir)
prefix = pngwriter.unique_prefix()
seed = "unknown_seed"
if "image" in metadata:
if "seed" in metadata["image"]:
seed = metadata["image"]["seed"]
filename = f"{prefix}.{seed}"
if step_index:
filename += f".{step_index}"
if postprocessing:
filename += f".postprocessed"
filename += ".png"
path = pngwriter.save_image_and_prompt_to_png(
image=image, dream_prompt=command, metadata=metadata, name=filename
)
return path
def calculate_real_steps(steps, strength, has_init_image):
return math.floor(strength * steps) if has_init_image else steps
def generate_images(generation_parameters, esrgan_parameters, gfpgan_parameters):
canceled.clear()
step_index = 1
prior_variations = (
generation_parameters["with_variations"]
if "with_variations" in generation_parameters
else []
)
"""
If a result image is used as an init image, and then deleted, we will want to be
able to use it as an init image in the future. Need to copy it.
If the init/mask image doesn't exist in the init_image_path/mask_image_path,
make a unique filename for it and copy it there.
"""
if "init_img" in generation_parameters:
filename = os.path.basename(generation_parameters["init_img"])
if not os.path.exists(os.path.join(init_image_path, filename)):
unique_filename = make_unique_init_image_filename(filename)
new_path = os.path.join(init_image_path, unique_filename)
shutil.copy(generation_parameters["init_img"], new_path)
generation_parameters["init_img"] = new_path
if "init_mask" in generation_parameters:
filename = os.path.basename(generation_parameters["init_mask"])
if not os.path.exists(os.path.join(mask_image_path, filename)):
unique_filename = make_unique_init_image_filename(filename)
new_path = os.path.join(init_image_path, unique_filename)
shutil.copy(generation_parameters["init_img"], new_path)
generation_parameters["init_mask"] = new_path
totalSteps = calculate_real_steps(
steps=generation_parameters["steps"],
strength=generation_parameters["strength"]
if "strength" in generation_parameters
else None,
has_init_image="init_img" in generation_parameters,
)
progress = {
"currentStep": 1,
"totalSteps": totalSteps,
"currentIteration": 1,
"totalIterations": generation_parameters["iterations"],
"currentStatus": "Preparing",
"isProcessing": True,
"currentStatusHasSteps": False,
}
socketio.emit("progressUpdate", progress)
eventlet.sleep(0)
def image_progress(sample, step):
if canceled.is_set():
raise CanceledException
nonlocal step_index
nonlocal generation_parameters
nonlocal progress
progress["currentStep"] = step + 1
progress["currentStatus"] = "Generating"
progress["currentStatusHasSteps"] = True
if (
generation_parameters["progress_images"]
and step % 5 == 0
and step < generation_parameters["steps"] - 1
):
image = generate.sample_to_image(sample)
metadata = parameters_to_generated_image_metadata(generation_parameters)
command = parameters_to_command(generation_parameters)
path = save_image(image, command, metadata, intermediate_path, step_index=step_index, postprocessing=False)
step_index += 1
socketio.emit(
"intermediateResult",
{
"url": os.path.relpath(path),
"mtime": os.path.getmtime(path),
"metadata": metadata,
},
)
socketio.emit("progressUpdate", progress)
eventlet.sleep(0)
def image_done(image, seed, first_seed):
nonlocal generation_parameters
nonlocal esrgan_parameters
nonlocal gfpgan_parameters
nonlocal progress
step_index = 1
nonlocal prior_variations
progress["currentStatus"] = "Generation complete"
socketio.emit("progressUpdate", progress)
eventlet.sleep(0)
all_parameters = generation_parameters
postprocessing = False
if (
"variation_amount" in all_parameters
and all_parameters["variation_amount"] > 0
):
first_seed = first_seed or seed
this_variation = [[seed, all_parameters["variation_amount"]]]
all_parameters["with_variations"] = prior_variations + this_variation
all_parameters["seed"] = first_seed
elif ("with_variations" in all_parameters):
all_parameters["seed"] = first_seed
else:
all_parameters["seed"] = seed
if esrgan_parameters:
progress["currentStatus"] = "Upscaling"
progress["currentStatusHasSteps"] = False
socketio.emit("progressUpdate", progress)
eventlet.sleep(0)
image = esrgan.process(
image=image,
upsampler_scale=esrgan_parameters["level"],
strength=esrgan_parameters["strength"],
seed=seed,
)
postprocessing = True
all_parameters["upscale"] = [
esrgan_parameters["level"],
esrgan_parameters["strength"],
]
if gfpgan_parameters:
progress["currentStatus"] = "Fixing faces"
progress["currentStatusHasSteps"] = False
socketio.emit("progressUpdate", progress)
eventlet.sleep(0)
image = gfpgan.process(
image=image, strength=gfpgan_parameters["strength"], seed=seed
)
postprocessing = True
all_parameters["facetool_strength"] = gfpgan_parameters["strength"]
progress["currentStatus"] = "Saving image"
socketio.emit("progressUpdate", progress)
eventlet.sleep(0)
metadata = parameters_to_generated_image_metadata(all_parameters)
command = parameters_to_command(all_parameters)
path = save_image(
image, command, metadata, result_path, postprocessing=postprocessing
)
print(f'>> Image generated: "{path}"')
write_log_message(f'[Generated] "{path}": {command}')
if progress["totalIterations"] > progress["currentIteration"]:
progress["currentStep"] = 1
progress["currentIteration"] += 1
progress["currentStatus"] = "Iteration finished"
progress["currentStatusHasSteps"] = False
else:
progress["currentStep"] = 0
progress["totalSteps"] = 0
progress["currentIteration"] = 0
progress["totalIterations"] = 0
progress["currentStatus"] = "Finished"
progress["isProcessing"] = False
socketio.emit("progressUpdate", progress)
eventlet.sleep(0)
socketio.emit(
"generationResult",
{
"url": os.path.relpath(path),
"mtime": os.path.getmtime(path),
"metadata": metadata,
},
)
eventlet.sleep(0)
try:
generate.prompt2image(
**generation_parameters,
step_callback=image_progress,
image_callback=image_done,
)
except KeyboardInterrupt:
raise
except CanceledException:
pass
except Exception as e:
socketio.emit("error", {"message": (str(e))})
print("\n")
traceback.print_exc()
print("\n")
"""
END ADDITIONAL FUNCTIONS
"""
if __name__ == "__main__":
print(f">> Starting server at http://{host}:{port}")
socketio.run(app, host=host, port=port)

View File

@ -12,6 +12,8 @@
* 'gfpgan'.
*/
import { Category as GalleryCategory } from '../features/gallery/gallerySlice';
/**
* TODO:
* Once an image has been generated, if it is postprocessed again,
@ -105,14 +107,15 @@ export declare type Metadata = SystemConfig & {
image: GeneratedImageMetadata | PostProcessedImageMetadata;
};
// An Image has a UUID, url (path?) and Metadata.
// An Image has a UUID, url, modified timestamp, width, height and maybe metadata
export declare type Image = {
uuid: string;
url: string;
mtime: number;
metadata: Metadata;
metadata?: Metadata;
width: number;
height: number;
category: GalleryCategory;
};
// GalleryImages is an array of Image.
@ -167,13 +170,9 @@ export declare type SystemStatusResponse = SystemStatus;
export declare type SystemConfigResponse = SystemConfig;
export declare type ImageResultResponse = {
url: string;
mtime: number;
metadata: Metadata;
width: number;
height: number;
};
export declare type ImageResultResponse = Omit<Image, 'uuid'>;
export declare type ImageUploadResponse = Omit<Image, 'uuid' | 'metadata'>;
export declare type ErrorResponse = {
message: string;
@ -183,11 +182,13 @@ export declare type ErrorResponse = {
export declare type GalleryImagesResponse = {
images: Array<Omit<Image, 'uuid'>>;
areMoreImagesAvailable: boolean;
category: GalleryCategory;
};
export declare type ImageUrlAndUuidResponse = {
export declare type ImageDeletedResponse = {
uuid: string;
url: string;
category: GalleryCategory;
};
export declare type ImageUrlResponse = {

View File

@ -1,4 +1,5 @@
import { createAction } from '@reduxjs/toolkit';
import { GalleryCategory } from '../../features/gallery/gallerySlice';
import { InvokeTabName } from '../../features/tabs/InvokeTabs';
import * as InvokeAI from '../invokeai';
@ -15,8 +16,8 @@ export const generateImage = createAction<InvokeTabName>(
export const runESRGAN = createAction<InvokeAI.Image>('socketio/runESRGAN');
export const runFacetool = createAction<InvokeAI.Image>('socketio/runFacetool');
export const deleteImage = createAction<InvokeAI.Image>('socketio/deleteImage');
export const requestImages = createAction<undefined>('socketio/requestImages');
export const requestNewImages = createAction<undefined>(
export const requestImages = createAction<GalleryCategory>('socketio/requestImages');
export const requestNewImages = createAction<GalleryCategory>(
'socketio/requestNewImages'
);
export const cancelProcessing = createAction<undefined>(

View File

@ -5,6 +5,11 @@ import {
frontendToBackendParameters,
FrontendToBackendParametersConfig,
} from '../../common/util/parameterTranslation';
import {
GalleryCategory,
GalleryState,
} from '../../features/gallery/gallerySlice';
import { OptionsState } from '../../features/options/optionsSlice';
import {
addLogEntry,
errorOccurred,
@ -108,7 +113,8 @@ const makeSocketIOEmitters = (
},
emitRunESRGAN: (imageToProcess: InvokeAI.Image) => {
dispatch(setIsProcessing(true));
const { upscalingLevel, upscalingStrength } = getState().options;
const options: OptionsState = getState().options;
const { upscalingLevel, upscalingStrength } = options;
const esrganParameters = {
upscale: [upscalingLevel, upscalingStrength],
};
@ -128,8 +134,8 @@ const makeSocketIOEmitters = (
},
emitRunFacetool: (imageToProcess: InvokeAI.Image) => {
dispatch(setIsProcessing(true));
const { facetoolType, facetoolStrength, codeformerFidelity } =
getState().options;
const options: OptionsState = getState().options;
const { facetoolType, facetoolStrength, codeformerFidelity } = options;
const facetoolParameters: Record<string, any> = {
facetool_strength: facetoolStrength,
@ -156,16 +162,18 @@ const makeSocketIOEmitters = (
);
},
emitDeleteImage: (imageToDelete: InvokeAI.Image) => {
const { url, uuid } = imageToDelete;
socketio.emit('deleteImage', url, uuid);
const { url, uuid, category } = imageToDelete;
socketio.emit('deleteImage', url, uuid, category);
},
emitRequestImages: () => {
const { earliest_mtime } = getState().gallery;
socketio.emit('requestImages', earliest_mtime);
emitRequestImages: (category: GalleryCategory) => {
const gallery: GalleryState = getState().gallery;
const { earliest_mtime } = gallery.categories[category];
socketio.emit('requestImages', category, earliest_mtime);
},
emitRequestNewImages: () => {
const { latest_mtime } = getState().gallery;
socketio.emit('requestLatestImages', latest_mtime);
emitRequestNewImages: (category: GalleryCategory) => {
const gallery: GalleryState = getState().gallery;
const { latest_mtime } = gallery.categories[category];
socketio.emit('requestLatestImages', category, latest_mtime);
},
emitCancelProcessing: () => {
socketio.emit('cancel');

View File

@ -21,12 +21,14 @@ import {
addGalleryImages,
addImage,
clearIntermediateImage,
GalleryState,
removeImage,
setIntermediateImage,
} from '../../features/gallery/gallerySlice';
import {
setInitialImagePath,
clearInitialImage,
setInitialImage,
setMaskPath,
} from '../../features/options/optionsSlice';
import { requestImages, requestNewImages } from './actions';
@ -48,10 +50,18 @@ const makeSocketIOListeners = (
try {
dispatch(setIsConnected(true));
dispatch(setCurrentStatus('Connected'));
if (getState().gallery.latest_mtime) {
dispatch(requestNewImages());
const gallery: GalleryState = getState().gallery;
if (gallery.categories.user.latest_mtime) {
dispatch(requestNewImages('user'));
} else {
dispatch(requestImages());
dispatch(requestImages('user'));
}
if (gallery.categories.result.latest_mtime) {
dispatch(requestNewImages('result'));
} else {
dispatch(requestImages('result'));
}
} catch (e) {
console.error(e);
@ -83,8 +93,11 @@ const makeSocketIOListeners = (
try {
dispatch(
addImage({
uuid: uuidv4(),
...data,
category: 'result',
image: {
uuid: uuidv4(),
...data,
},
})
);
dispatch(
@ -125,8 +138,11 @@ const makeSocketIOListeners = (
try {
dispatch(
addImage({
uuid: uuidv4(),
...data,
category: 'result',
image: {
uuid: uuidv4(),
...data,
},
})
);
@ -180,7 +196,7 @@ const makeSocketIOListeners = (
* Callback to run when we receive a 'galleryImages' event.
*/
onGalleryImages: (data: InvokeAI.GalleryImagesResponse) => {
const { images, areMoreImagesAvailable } = data;
const { images, areMoreImagesAvailable, category } = data;
/**
* the logic here ideally would be in the reducer but we have a side effect:
@ -189,19 +205,18 @@ const makeSocketIOListeners = (
// Generate a UUID for each image
const preparedImages = images.map((image): InvokeAI.Image => {
const { url, metadata, mtime, width, height } = image;
return {
uuid: uuidv4(),
url,
mtime,
metadata,
width,
height,
...image,
};
});
dispatch(
addGalleryImages({ images: preparedImages, areMoreImagesAvailable })
addGalleryImages({
images: preparedImages,
areMoreImagesAvailable,
category,
})
);
dispatch(
@ -220,7 +235,12 @@ const makeSocketIOListeners = (
const { intermediateImage } = getState().gallery;
if (intermediateImage) {
dispatch(addImage(intermediateImage));
dispatch(
addImage({
category: 'result',
image: intermediateImage,
})
);
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
@ -241,14 +261,17 @@ const makeSocketIOListeners = (
/**
* Callback to run when we receive a 'imageDeleted' event.
*/
onImageDeleted: (data: InvokeAI.ImageUrlAndUuidResponse) => {
const { url, uuid } = data;
dispatch(removeImage(uuid));
onImageDeleted: (data: InvokeAI.ImageDeletedResponse) => {
const { url, uuid, category } = data;
const { initialImagePath, maskPath } = getState().options;
// remove image from gallery
dispatch(removeImage(data));
if (initialImagePath === url) {
dispatch(setInitialImagePath(''));
// remove references to image in options
const { initialImage, maskPath } = getState().options;
if (initialImage?.url === url || initialImage === url) {
dispatch(clearInitialImage());
}
if (maskPath === url) {
@ -262,18 +285,25 @@ const makeSocketIOListeners = (
})
);
},
/**
* Callback to run when we receive a 'initialImageUploaded' event.
*/
onInitialImageUploaded: (data: InvokeAI.ImageUrlResponse) => {
const { url } = data;
dispatch(setInitialImagePath(url));
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `Initial image uploaded: ${url}`,
})
);
onInitialImageUploaded: (data: InvokeAI.ImageUploadResponse) => {
const image = {
uuid: uuidv4(),
...data,
};
try {
dispatch(addImage({ image, category: 'user' }));
dispatch(setInitialImage(image));
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `Image uploaded: ${data.url}`,
})
);
} catch (e) {
console.error(e);
}
},
/**
* Callback to run when we receive a 'maskImageUploaded' event.

View File

@ -100,13 +100,20 @@ export const socketioMiddleware = () => {
onProcessingCanceled();
});
socketio.on('imageDeleted', (data: InvokeAI.ImageUrlAndUuidResponse) => {
socketio.on('imageDeleted', (data: InvokeAI.ImageDeletedResponse) => {
onImageDeleted(data);
});
socketio.on('initialImageUploaded', (data: InvokeAI.ImageUrlResponse) => {
onInitialImageUploaded(data);
});
// socketio.on('initialImageUploaded', (data: InvokeAI.ImageUrlResponse) => {
// onInitialImageUploaded(data);
// });
socketio.on(
'initialImageUploaded',
(data: InvokeAI.ImageUploadResponse) => {
onInitialImageUploaded(data);
}
);
socketio.on('maskImageUploaded', (data: InvokeAI.ImageUrlResponse) => {
onMaskImageUploaded(data);
@ -152,12 +159,12 @@ export const socketioMiddleware = () => {
}
case 'socketio/requestImages': {
emitRequestImages();
emitRequestImages(action.payload);
break;
}
case 'socketio/requestNewImages': {
emitRequestNewImages();
emitRequestNewImages(action.payload);
break;
}

View File

@ -23,7 +23,7 @@ export const useCheckParametersSelector = createSelector(
shouldGenerateVariations: options.shouldGenerateVariations,
seedWeights: options.seedWeights,
maskPath: options.maskPath,
initialImagePath: options.initialImagePath,
initialImage: options.initialImage,
seed: options.seed,
activeTabName: tabMap[options.activeTab],
// system
@ -49,7 +49,7 @@ const useCheckParameters = (): boolean => {
shouldGenerateVariations,
seedWeights,
maskPath,
initialImagePath,
initialImage,
seed,
activeTabName,
isProcessing,
@ -63,7 +63,7 @@ const useCheckParameters = (): boolean => {
return false;
}
if (activeTabName === 'img2img' && !initialImagePath) {
if (activeTabName === 'img2img' && !initialImage) {
return false;
}
@ -72,7 +72,7 @@ const useCheckParameters = (): boolean => {
}
// Cannot generate with a mask without img2img
if (maskPath && !initialImagePath) {
if (maskPath && !initialImage) {
return false;
}
@ -100,8 +100,8 @@ const useCheckParameters = (): boolean => {
}, [
prompt,
maskPath,
initialImagePath,
isProcessing,
initialImage,
isConnected,
shouldGenerateVariations,
seedWeights,

View File

@ -47,7 +47,7 @@ export const frontendToBackendParameters = (
seamless,
hiresFix,
img2imgStrength,
initialImagePath,
initialImage,
shouldFitToWidthHeight,
shouldGenerateVariations,
variationAmount,
@ -89,8 +89,9 @@ export const frontendToBackendParameters = (
}
// img2img exclusive parameters
if (generationMode === 'img2img') {
generationParameters.init_img = initialImagePath;
if (generationMode === 'img2img' && initialImage) {
generationParameters.init_img =
typeof initialImage === 'string' ? initialImage : initialImage.url;
generationParameters.strength = img2imgStrength;
generationParameters.fit = shouldFitToWidthHeight;
}

View File

@ -8,7 +8,7 @@ import { RootState } from '../../app/store';
import {
setActiveTab,
setAllParameters,
setInitialImagePath,
setInitialImage,
setSeed,
setShouldShowImageDetails,
} from '../options/optionsSlice';
@ -85,7 +85,7 @@ const CurrentImageButtons = ({ image }: CurrentImageButtonsProps) => {
useAppSelector(systemSelector);
const handleClickUseAsInitialImage = () => {
dispatch(setInitialImagePath(image.url));
dispatch(setInitialImage(image));
dispatch(setActiveTab(1));
};
@ -114,7 +114,8 @@ const CurrentImageButtons = ({ image }: CurrentImageButtonsProps) => {
);
const handleClickUseAllParameters = () =>
dispatch(setAllParameters(image.metadata));
image.metadata && dispatch(setAllParameters(image.metadata));
useHotkeys(
'a',
() => {
@ -139,9 +140,7 @@ const CurrentImageButtons = ({ image }: CurrentImageButtonsProps) => {
[image]
);
// Non-null assertion: this button is disabled if there is no seed.
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
const handleClickUseSeed = () => dispatch(setSeed(image.metadata.image.seed));
const handleClickUseSeed = () => image.metadata && dispatch(setSeed(image.metadata.image.seed));
useHotkeys(
's',
() => {

View File

@ -10,11 +10,15 @@ import _ from 'lodash';
const imagesSelector = createSelector(
(state: RootState) => state.gallery,
(gallery: GalleryState) => {
const currentImageIndex = gallery.images.findIndex(
const { currentCategory } = gallery;
const tempImages = gallery.categories[currentCategory].images;
const currentImageIndex = tempImages.findIndex(
(i) => i.uuid === gallery?.currentImage?.uuid
);
const imagesLength = gallery.images.length;
const imagesLength = tempImages.length;
return {
currentCategory,
isOnFirstImage: currentImageIndex === 0,
isOnLastImage:
!isNaN(currentImageIndex) && currentImageIndex === imagesLength - 1,
@ -35,7 +39,7 @@ export default function CurrentImagePreview(props: CurrentImagePreviewProps) {
const { imageToDisplay } = props;
const dispatch = useAppDispatch();
const { isOnFirstImage, isOnLastImage } = useAppSelector(imagesSelector);
const { isOnFirstImage, isOnLastImage, currentCategory } = useAppSelector(imagesSelector);
const shouldShowImageDetails = useAppSelector(
(state: RootState) => state.options.shouldShowImageDetails
@ -53,11 +57,11 @@ export default function CurrentImagePreview(props: CurrentImagePreviewProps) {
};
const handleClickPrevButton = () => {
dispatch(selectPrevImage());
dispatch(selectPrevImage(currentCategory));
};
const handleClickNextButton = () => {
dispatch(selectNextImage());
dispatch(selectNextImage(currentCategory));
};
return (

View File

@ -19,7 +19,7 @@ import {
setActiveTab,
setAllImageToImageParameters,
setAllTextToImageParameters,
setInitialImagePath,
setInitialImage,
setPrompt,
setSeed,
} from '../options/optionsSlice';
@ -58,7 +58,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
const handleMouseOut = () => setIsHovered(false);
const handleUsePrompt = () => {
dispatch(setPrompt(image.metadata.image.prompt));
image.metadata && dispatch(setPrompt(image.metadata.image.prompt));
toast({
title: 'Prompt Set',
status: 'success',
@ -68,7 +68,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
};
const handleUseSeed = () => {
dispatch(setSeed(image.metadata.image.seed));
image.metadata && dispatch(setSeed(image.metadata.image.seed));
toast({
title: 'Seed Set',
status: 'success',
@ -78,7 +78,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
};
const handleSendToImageToImage = () => {
dispatch(setInitialImagePath(image.url));
dispatch(setInitialImage(image));
if (activeTabName !== 'img2img') {
dispatch(setActiveTab('img2img'));
}
@ -104,7 +104,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
};
const handleUseAllParameters = () => {
dispatch(setAllTextToImageParameters(metadata));
metadata && dispatch(setAllTextToImageParameters(metadata));
toast({
title: 'Parameters Set',
status: 'success',

View File

@ -126,6 +126,22 @@
}
}
.image-gallery-category-btn-group {
width: 100% !important;
column-gap: 0 !important;
justify-content: stretch !important;
button {
flex-grow: 1;
&[data-selected='true'] {
background-color: var(--accent-color);
&:hover {
background-color: var(--accent-color-hover);
}
}
}
}
// from https://css-tricks.com/a-grid-of-logos-in-squares/
.image-gallery {
display: grid;

View File

@ -11,6 +11,7 @@ import IAIIconButton from '../../common/components/IAIIconButton';
import {
selectNextImage,
selectPrevImage,
setCurrentCategory,
setGalleryImageMinimumWidth,
setGalleryImageObjectFit,
setGalleryScrollPosition,
@ -20,7 +21,7 @@ import {
} from './gallerySlice';
import HoverableImage from './HoverableImage';
import { setShouldShowGallery } from '../gallery/gallerySlice';
import { Spacer, useToast } from '@chakra-ui/react';
import { ButtonGroup, Spacer, useToast } from '@chakra-ui/react';
import { CSSTransition } from 'react-transition-group';
import { Direction } from 're-resizable/lib/resizer';
import { imageGallerySelector } from './gallerySliceSelectors';
@ -36,8 +37,8 @@ export default function ImageGallery() {
const {
images,
currentCategory,
currentImageUuid,
areMoreImagesAvailable,
shouldPinGallery,
shouldShowGallery,
galleryScrollPosition,
@ -47,6 +48,7 @@ export default function ImageGallery() {
galleryImageObjectFit,
shouldHoldGalleryOpen,
shouldAutoSwitchToNewImages,
areMoreImagesAvailable,
} = useAppSelector(imageGallerySelector);
const [gallerySize, setGallerySize] = useState<Size>({
@ -128,7 +130,7 @@ export default function ImageGallery() {
};
const handleClickLoadMore = () => {
dispatch(requestImages());
dispatch(requestImages(currentCategory));
};
const handleChangeGalleryImageMinimumWidth = (v: number) => {
@ -151,13 +153,21 @@ export default function ImageGallery() {
[shouldShowGallery]
);
useHotkeys('left', () => {
dispatch(selectPrevImage());
});
useHotkeys(
'left',
() => {
dispatch(selectPrevImage(currentCategory));
},
[currentCategory]
);
useHotkeys('right', () => {
dispatch(selectNextImage());
});
useHotkeys(
'right',
() => {
dispatch(selectNextImage(currentCategory));
},
[currentCategory]
);
useHotkeys(
'shift+p',
@ -317,7 +327,7 @@ export default function ImageGallery() {
) : null}
<IAIPopover
trigger="click"
trigger="hover"
hasArrow={activeTabName === 'inpainting' ? false : true}
// styleClass="image-gallery-settings-popover"
triggerComponent={
@ -331,6 +341,27 @@ export default function ImageGallery() {
}
>
<div className="image-gallery-settings-popover">
<div>
<ButtonGroup
size="sm"
isAttached
variant="solid"
className="image-gallery-category-btn-group"
>
<Button
data-selected={currentCategory === 'result'}
onClick={() => dispatch(setCurrentCategory('result'))}
>
Invocations
</Button>
<Button
data-selected={currentCategory === 'user'}
onClick={() => dispatch(setCurrentCategory('user'))}
>
Uploads
</Button>
</ButtonGroup>
</div>
<div>
<IAISlider
value={galleryImageMinimumWidth}

View File

@ -1,129 +0,0 @@
import {
Button,
Drawer,
DrawerBody,
DrawerCloseButton,
DrawerContent,
DrawerHeader,
useDisclosure,
} from '@chakra-ui/react';
import React from 'react';
import { useHotkeys } from 'react-hotkeys-hook';
import { MdPhotoLibrary } from 'react-icons/md';
import { requestImages } from '../../app/socketio/actions';
import { RootState, useAppDispatch } from '../../app/store';
import { useAppSelector } from '../../app/store';
import { selectNextImage, selectPrevImage } from './gallerySlice';
import HoverableImage from './HoverableImage';
/**
* Simple image gallery.
*/
const ImageGalleryOld = () => {
const { images, currentImageUuid, areMoreImagesAvailable } = useAppSelector(
(state: RootState) => state.gallery
);
const dispatch = useAppDispatch();
const { isOpen, onOpen, onClose } = useDisclosure();
/**
* I don't like that this needs to rerender whenever the current image is changed.
* What if we have a large number of images? I suppose pagination (planned) will
* mitigate this issue.
*
* TODO: Refactor if performance complaints, or after migrating to new API which supports pagination.
*/
const handleClickLoadMore = () => {
dispatch(requestImages());
};
useHotkeys(
'g',
() => {
if (isOpen) {
onClose();
} else {
onOpen();
}
},
[isOpen]
);
useHotkeys(
'left',
() => {
dispatch(selectPrevImage());
},
[]
);
useHotkeys(
'right',
() => {
dispatch(selectNextImage());
},
[]
);
return (
<div className="image-gallery-area">
<Button
colorScheme="teal"
onClick={onOpen}
className="image-gallery-popup-btn"
>
<MdPhotoLibrary />
</Button>
<Drawer
isOpen={isOpen}
placement="right"
onClose={onClose}
autoFocus={false}
trapFocus={false}
closeOnOverlayClick={false}
>
<DrawerContent className="image-gallery-popup">
<div className="image-gallery-header">
<DrawerHeader>Your Invocations</DrawerHeader>
<DrawerCloseButton />
</div>
<DrawerBody className="image-gallery-body">
<div className="image-gallery-container">
{images.length ? (
<div className="image-gallery">
{images.map((image) => {
const { uuid } = image;
const isSelected = currentImageUuid === uuid;
return (
<HoverableImage
key={uuid}
image={image}
isSelected={isSelected}
/>
);
})}
</div>
) : (
<div className="image-gallery-container-placeholder">
<MdPhotoLibrary />
<p>No Images In Gallery</p>
</div>
)}
<Button
onClick={handleClickLoadMore}
isDisabled={!areMoreImagesAvailable}
className="image-gallery-load-more-btn"
>
{areMoreImagesAvailable ? 'Load More' : 'All Images Loaded'}
</Button>
</div>
</DrawerBody>
</DrawerContent>
</Drawer>
</div>
);
};
export default ImageGalleryOld;

View File

@ -20,7 +20,6 @@ import {
setHeight,
setHiresFix,
setImg2imgStrength,
setInitialImagePath,
setMaskPath,
setPrompt,
setSampler,
@ -32,6 +31,7 @@ import {
setUpscalingLevel,
setUpscalingStrength,
setWidth,
setInitialImage,
} from '../../options/optionsSlice';
import promptToString from '../../../common/util/promptToString';
import { seedWeightsToString } from '../../../common/util/seedWeightPairs';
@ -248,7 +248,7 @@ const ImageMetadataViewer = memo(
label="Initial image"
value={init_image_path}
isLink
onClick={() => dispatch(setInitialImagePath(init_image_path))}
onClick={() => dispatch(setInitialImage(init_image_path))}
/>
)}
{mask_image_path && (

View File

@ -3,16 +3,27 @@ import type { PayloadAction } from '@reduxjs/toolkit';
import _, { clamp } from 'lodash';
import * as InvokeAI from '../../app/invokeai';
export type GalleryCategory = 'user' | 'result';
export type AddImagesPayload = {
images: Array<InvokeAI.Image>;
areMoreImagesAvailable: boolean;
category: GalleryCategory;
};
type GalleryImageObjectFitType = 'contain' | 'cover';
export type Gallery = {
images: InvokeAI.Image[];
latest_mtime?: number;
earliest_mtime?: number;
areMoreImagesAvailable: boolean;
};
export interface GalleryState {
currentImage?: InvokeAI.Image;
currentImageUuid: string;
images: Array<InvokeAI.Image>;
intermediateImage?: InvokeAI.Image;
areMoreImagesAvailable: boolean;
latest_mtime?: number;
earliest_mtime?: number;
shouldPinGallery: boolean;
shouldShowGallery: boolean;
galleryScrollPosition: number;
@ -20,12 +31,15 @@ export interface GalleryState {
galleryImageObjectFit: GalleryImageObjectFitType;
shouldHoldGalleryOpen: boolean;
shouldAutoSwitchToNewImages: boolean;
categories: {
user: Gallery;
result: Gallery;
};
currentCategory: GalleryCategory;
}
const initialState: GalleryState = {
currentImageUuid: '',
images: [],
areMoreImagesAvailable: true,
shouldPinGallery: true,
shouldShowGallery: true,
galleryScrollPosition: 0,
@ -33,6 +47,21 @@ const initialState: GalleryState = {
galleryImageObjectFit: 'cover',
shouldHoldGalleryOpen: false,
shouldAutoSwitchToNewImages: true,
currentCategory: 'result',
categories: {
user: {
images: [],
latest_mtime: undefined,
earliest_mtime: undefined,
areMoreImagesAvailable: true,
},
result: {
images: [],
latest_mtime: undefined,
earliest_mtime: undefined,
areMoreImagesAvailable: true,
},
},
};
export const gallerySlice = createSlice({
@ -43,10 +72,15 @@ export const gallerySlice = createSlice({
state.currentImage = action.payload;
state.currentImageUuid = action.payload.uuid;
},
removeImage: (state, action: PayloadAction<string>) => {
const uuid = action.payload;
removeImage: (
state,
action: PayloadAction<InvokeAI.ImageDeletedResponse>
) => {
const { uuid, category } = action.payload;
const newImages = state.images.filter((image) => image.uuid !== uuid);
const tempImages = state.categories[category as GalleryCategory].images;
const newImages = tempImages.filter((image) => image.uuid !== uuid);
if (uuid === state.currentImageUuid) {
/**
@ -58,7 +92,7 @@ export const gallerySlice = createSlice({
*
* Get the currently selected image's index.
*/
const imageToDeleteIndex = state.images.findIndex(
const imageToDeleteIndex = tempImages.findIndex(
(image) => image.uuid === uuid
);
@ -84,24 +118,35 @@ export const gallerySlice = createSlice({
: '';
}
state.images = newImages;
state.categories[category as GalleryCategory].images = newImages;
},
addImage: (state, action: PayloadAction<InvokeAI.Image>) => {
const newImage = action.payload;
addImage: (
state,
action: PayloadAction<{
image: InvokeAI.Image;
category: GalleryCategory;
}>
) => {
const { image: newImage, category } = action.payload;
const { uuid, url, mtime } = newImage;
const tempCategory = state.categories[category as GalleryCategory];
// Do not add duplicate images
if (state.images.find((i) => i.url === url && i.mtime === mtime)) {
if (tempCategory.images.find((i) => i.url === url && i.mtime === mtime)) {
return;
}
state.images.unshift(newImage);
tempCategory.images.unshift(newImage);
if (state.shouldAutoSwitchToNewImages) {
state.currentImageUuid = uuid;
state.currentImage = newImage;
if (category === 'result') {
state.currentCategory = 'result';
}
}
state.intermediateImage = undefined;
state.latest_mtime = mtime;
tempCategory.latest_mtime = mtime;
},
setIntermediateImage: (state, action: PayloadAction<InvokeAI.Image>) => {
state.intermediateImage = action.payload;
@ -109,49 +154,53 @@ export const gallerySlice = createSlice({
clearIntermediateImage: (state) => {
state.intermediateImage = undefined;
},
selectNextImage: (state) => {
const { images, currentImage } = state;
selectNextImage: (state, action: PayloadAction<GalleryCategory>) => {
const category = action.payload;
const { currentImage } = state;
const tempImages = state.categories[category].images;
if (currentImage) {
const currentImageIndex = images.findIndex(
const currentImageIndex = tempImages.findIndex(
(i) => i.uuid === currentImage.uuid
);
if (_.inRange(currentImageIndex, 0, images.length)) {
const newCurrentImage = images[currentImageIndex + 1];
if (_.inRange(currentImageIndex, 0, tempImages.length)) {
const newCurrentImage = tempImages[currentImageIndex + 1];
state.currentImage = newCurrentImage;
state.currentImageUuid = newCurrentImage.uuid;
}
}
},
selectPrevImage: (state) => {
const { images, currentImage } = state;
selectPrevImage: (state, action: PayloadAction<GalleryCategory>) => {
const category = action.payload;
const { currentImage } = state;
const tempImages = state.categories[category].images;
if (currentImage) {
const currentImageIndex = images.findIndex(
const currentImageIndex = tempImages.findIndex(
(i) => i.uuid === currentImage.uuid
);
if (_.inRange(currentImageIndex, 1, images.length + 1)) {
const newCurrentImage = images[currentImageIndex - 1];
if (_.inRange(currentImageIndex, 1, tempImages.length + 1)) {
const newCurrentImage = tempImages[currentImageIndex - 1];
state.currentImage = newCurrentImage;
state.currentImageUuid = newCurrentImage.uuid;
}
}
},
addGalleryImages: (
state,
action: PayloadAction<{
images: Array<InvokeAI.Image>;
areMoreImagesAvailable: boolean;
}>
) => {
const { images, areMoreImagesAvailable } = action.payload;
addGalleryImages: (state, action: PayloadAction<AddImagesPayload>) => {
const { images, areMoreImagesAvailable, category } = action.payload;
const tempImages = state.categories[category].images;
// const prevImages = category === 'user' ? state.userImages : state.resultImages
if (images.length > 0) {
// Filter images that already exist in the gallery
const newImages = images.filter(
(newImage) =>
!state.images.find(
!tempImages.find(
(i) => i.url === newImage.url && i.mtime === newImage.mtime
)
);
state.images = state.images
state.categories[category].images = tempImages
.concat(newImages)
.sort((a, b) => b.mtime - a.mtime);
@ -162,11 +211,14 @@ export const gallerySlice = createSlice({
}
// keep track of the timestamps of latest and earliest images received
state.latest_mtime = images[0].mtime;
state.earliest_mtime = images[images.length - 1].mtime;
state.categories[category].latest_mtime = images[0].mtime;
state.categories[category].earliest_mtime =
images[images.length - 1].mtime;
}
if (areMoreImagesAvailable !== undefined) {
state.areMoreImagesAvailable = areMoreImagesAvailable;
state.categories[category].areMoreImagesAvailable =
areMoreImagesAvailable;
}
},
setShouldPinGallery: (state, action: PayloadAction<boolean>) => {
@ -193,6 +245,9 @@ export const gallerySlice = createSlice({
setShouldAutoSwitchToNewImages: (state, action: PayloadAction<boolean>) => {
state.shouldAutoSwitchToNewImages = action.payload;
},
setCurrentCategory: (state, action: PayloadAction<GalleryCategory>) => {
state.currentCategory = action.payload;
},
},
});
@ -212,6 +267,7 @@ export const {
setGalleryImageObjectFit,
setShouldHoldGalleryOpen,
setShouldAutoSwitchToNewImages,
setCurrentCategory,
} = gallerySlice.actions;
export default gallerySlice.reducer;

View File

@ -8,24 +8,22 @@ export const imageGallerySelector = createSelector(
[(state: RootState) => state.gallery, (state: RootState) => state.options],
(gallery: GalleryState, options: OptionsState) => {
const {
images,
categories,
currentCategory,
currentImageUuid,
areMoreImagesAvailable,
shouldPinGallery,
shouldShowGallery,
galleryScrollPosition,
galleryImageMinimumWidth,
galleryImageObjectFit,
shouldHoldGalleryOpen,
shouldAutoSwitchToNewImages
shouldAutoSwitchToNewImages,
} = gallery;
const { activeTab } = options;
return {
images,
currentImageUuid,
areMoreImagesAvailable,
shouldPinGallery,
shouldShowGallery,
galleryScrollPosition,
@ -34,7 +32,11 @@ export const imageGallerySelector = createSelector(
galleryGridTemplateColumns: `repeat(auto-fill, minmax(${galleryImageMinimumWidth}px, auto))`,
activeTabName: tabMap[activeTab],
shouldHoldGalleryOpen,
shouldAutoSwitchToNewImages
shouldAutoSwitchToNewImages,
images: categories[currentCategory].images,
areMoreImagesAvailable:
categories[currentCategory].areMoreImagesAvailable,
currentCategory,
};
}
);

View File

@ -27,8 +27,7 @@ export interface OptionsState {
codeformerFidelity: number;
upscalingLevel: UpscalingLevel;
upscalingStrength: number;
shouldUseInitImage: boolean;
initialImagePath: string | null;
initialImage?: InvokeAI.Image | string; // can be an Image or url
maskPath: string;
seamless: boolean;
hiresFix: boolean;
@ -58,9 +57,7 @@ const initialOptionsState: OptionsState = {
seed: 0,
seamless: false,
hiresFix: false,
shouldUseInitImage: false,
img2imgStrength: 0.75,
initialImagePath: null,
maskPath: '',
shouldFitToWidthHeight: true,
shouldGenerateVariations: false,
@ -137,14 +134,6 @@ export const optionsSlice = createSlice({
setUpscalingStrength: (state, action: PayloadAction<number>) => {
state.upscalingStrength = action.payload;
},
setShouldUseInitImage: (state, action: PayloadAction<boolean>) => {
state.shouldUseInitImage = action.payload;
},
setInitialImagePath: (state, action: PayloadAction<string | null>) => {
const newInitialImagePath = action.payload;
state.shouldUseInitImage = newInitialImagePath ? true : false;
state.initialImagePath = newInitialImagePath;
},
setMaskPath: (state, action: PayloadAction<string>) => {
state.maskPath = action.payload;
},
@ -170,9 +159,6 @@ export const optionsSlice = createSlice({
if (key === 'seed') {
temp.shouldRandomizeSeed = false;
}
if (key === 'initialImagePath' && value === '') {
temp.shouldUseInitImage = false;
}
return temp;
},
setShouldGenerateVariations: (state, action: PayloadAction<boolean>) => {
@ -236,13 +222,10 @@ export const optionsSlice = createSlice({
action.payload.image;
if (type === 'img2img') {
if (init_image_path) state.initialImagePath = init_image_path;
if (init_image_path) state.initialImage = init_image_path;
if (mask_image_path) state.maskPath = mask_image_path;
if (strength) state.img2imgStrength = strength;
if (typeof fit === 'boolean') state.shouldFitToWidthHeight = fit;
state.shouldUseInitImage = true;
} else {
state.shouldUseInitImage = false;
}
},
setAllParameters: (state, action: PayloadAction<InvokeAI.Metadata>) => {
@ -267,13 +250,10 @@ export const optionsSlice = createSlice({
} = action.payload.image;
if (type === 'img2img') {
if (init_image_path) state.initialImagePath = init_image_path;
if (init_image_path) state.initialImage = init_image_path;
if (mask_image_path) state.maskPath = mask_image_path;
if (strength) state.img2imgStrength = strength;
if (typeof fit === 'boolean') state.shouldFitToWidthHeight = fit;
state.shouldUseInitImage = true;
} else {
state.shouldUseInitImage = false;
}
if (variations && variations.length > 0) {
@ -335,6 +315,15 @@ export const optionsSlice = createSlice({
setShowDualDisplay: (state, action: PayloadAction<boolean>) => {
state.showDualDisplay = action.payload;
},
setInitialImage: (
state,
action: PayloadAction<InvokeAI.Image | string>
) => {
state.initialImage = action.payload;
},
clearInitialImage: (state) => {
state.initialImage = undefined;
},
},
});
@ -357,8 +346,6 @@ export const {
setCodeformerFidelity,
setUpscalingLevel,
setUpscalingStrength,
setShouldUseInitImage,
setInitialImagePath,
setMaskPath,
resetSeed,
resetOptionsState,
@ -377,6 +364,8 @@ export const {
setAllTextToImageParameters,
setAllImageToImageParameters,
setShowDualDisplay,
setInitialImage,
clearInitialImage,
} = optionsSlice.actions;
export default optionsSlice.reducer;

View File

@ -1,4 +1,5 @@
.console {
width: 100vw;
display: flex;
flex-direction: column;
background: var(--console-bg-color);

View File

@ -10,8 +10,8 @@ import ImageMetadataViewer from '../../gallery/ImageMetaDataViewer/ImageMetadata
import InitImagePreview from './InitImagePreview';
export default function ImageToImageDisplay() {
const initialImagePath = useAppSelector(
(state: RootState) => state.options.initialImagePath
const initialImage = useAppSelector(
(state: RootState) => state.options.initialImage
);
const { currentImage, intermediateImage } = useAppSelector(
@ -33,7 +33,7 @@ export default function ImageToImageDisplay() {
: { gridAutoRows: 'auto' }
}
>
{initialImagePath ? (
{initialImage ? (
<>
{imageToDisplay ? (
<>

View File

@ -2,12 +2,10 @@ import { IconButton, Image, useToast } from '@chakra-ui/react';
import React, { SyntheticEvent } from 'react';
import { MdClear } from 'react-icons/md';
import { RootState, useAppDispatch, useAppSelector } from '../../../app/store';
import { setInitialImagePath } from '../../options/optionsSlice';
import { clearInitialImage } from '../../options/optionsSlice';
export default function InitImagePreview() {
const initialImagePath = useAppSelector(
(state: RootState) => state.options.initialImagePath
);
const { initialImage } = useAppSelector((state: RootState) => state.options);
const dispatch = useAppDispatch();
@ -15,7 +13,7 @@ export default function InitImagePreview() {
const handleClickResetInitialImage = (e: SyntheticEvent) => {
e.stopPropagation();
dispatch(setInitialImagePath(null));
dispatch(clearInitialImage());
};
const alertMissingInitImage = () => {
@ -25,7 +23,7 @@ export default function InitImagePreview() {
status: 'error',
isClosable: true,
});
dispatch(setInitialImagePath(null));
dispatch(clearInitialImage());
};
return (
@ -33,18 +31,20 @@ export default function InitImagePreview() {
<div className="init-image-preview-header">
<h1>Initial Image</h1>
<IconButton
isDisabled={!initialImagePath}
isDisabled={!initialImage}
size={'sm'}
aria-label={'Reset Initial Image'}
onClick={handleClickResetInitialImage}
icon={<MdClear />}
/>
</div>
{initialImagePath && (
{initialImage && (
<div className="init-image-image">
<Image
fit={'contain'}
src={initialImagePath}
src={
typeof initialImage === 'string' ? initialImage : initialImage.url
}
rounded={'md'}
onError={alertMissingInitImage}
/>

View File

@ -3,14 +3,14 @@ import React from 'react';
import { RootState, useAppSelector } from '../../../app/store';
export default function InitialImageOverlay() {
const initialImagePath = useAppSelector(
(state: RootState) => state.options.initialImagePath
const initialImage = useAppSelector(
(state: RootState) => state.options.initialImage
);
return initialImagePath ? (
return initialImage ? (
<Image
fit={'contain'}
src={initialImagePath}
src={typeof initialImage === 'string' ? initialImage : initialImage.url}
rounded={'md'}
className={'checkerboard'}
/>

View File

@ -12,9 +12,8 @@ type InvokeWorkareaProps = {
const InvokeWorkarea = (props: InvokeWorkareaProps) => {
const { optionsPanel, className, children } = props;
const { shouldShowGallery, shouldHoldGalleryOpen } = useAppSelector(
(state: RootState) => state.gallery
);
const { shouldShowGallery, shouldHoldGalleryOpen, shouldPinGallery } =
useAppSelector((state: RootState) => state.gallery);
return (
<div
@ -27,7 +26,7 @@ const InvokeWorkarea = (props: InvokeWorkareaProps) => {
<div className="workarea-content">{children}</div>
<ImageGallery />
</div>
{!(shouldShowGallery || shouldHoldGalleryOpen) && (
{!(shouldShowGallery || (shouldHoldGalleryOpen && !shouldPinGallery)) && (
<ShowHideGalleryButton />
)}
</div>